diff --git a/Cargo.lock b/Cargo.lock index d900b67a3..596756b97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3750,8 +3750,11 @@ dependencies = [ name = "openshell-policy" version = "0.0.0" dependencies = [ + "glob", "miette", "openshell-core", + "openshell-supervisor-middleware", + "prost-types", "serde", "serde_json", "serde_yml", @@ -3811,6 +3814,7 @@ dependencies = [ "openshell-core", "openshell-ocsf", "openshell-policy", + "openshell-supervisor-middleware", "openshell-supervisor-network", "openshell-supervisor-process", "rustls", @@ -3865,6 +3869,7 @@ dependencies = [ "openshell-providers", "openshell-router", "openshell-server-macros", + "openshell-supervisor-middleware", "petname", "pin-project-lite", "prost", @@ -3907,6 +3912,19 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "openshell-supervisor-middleware" +version = "0.0.0" +dependencies = [ + "miette", + "openshell-core", + "prost-types", + "regex", + "tokio", + "tokio-stream", + "tonic", +] + [[package]] name = "openshell-supervisor-network" version = "0.0.0" @@ -3929,6 +3947,8 @@ dependencies = [ "openshell-ocsf", "openshell-policy", "openshell-router", + "openshell-supervisor-middleware", + "prost-types", "rcgen", "regorus", "reqwest 0.12.28", @@ -3948,6 +3968,7 @@ dependencies = [ "tokio-tungstenite 0.26.2", "tower-mcp-types", "tracing", + "tracing-subscriber", "uuid", "webpki-roots 1.0.7", ] diff --git a/Cargo.toml b/Cargo.toml index f450cd5c8..fd3641d68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ serde_yml = "0.0.12" toml = "0.8" apollo-parser = "0.8.5" tower-mcp-types = "0.12.0" +regex = "1" # HTTP client reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls-native-roots"] } diff --git a/architecture/gateway.md b/architecture/gateway.md index d873b2a10..ba8437b7f 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -271,6 +271,14 @@ config path. A gateway-global policy can override sandbox-scoped policy. 
The sandbox supervisor polls for config revisions and hot-reloads dynamic policy when the policy engine accepts the update. +External supervisor middleware registration is operator-owned gateway +configuration. At startup the gateway connects to each service, validates its +described bindings and operator body limit, and rejects duplicate binding IDs. +Before persisting a policy, the gateway asks each selected implementation to +validate its config. The effective sandbox config contains only the registered +services required by that policy; supervisors invoke those services directly on +the request path. + Provider credential expiry is enforced during gateway-to-sandbox credential resolution and again by the sandbox placeholder resolver. This keeps expired credentials from resolving even when a running sandbox still has retained diff --git a/architecture/sandbox.md b/architecture/sandbox.md index 31a357abf..a97e81322 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -61,6 +61,14 @@ matchers; generic JSON-RPC rules match only the method. JSON-RPC responses and server-to-client MCP messages on response or SSE streams are relayed but are not currently parsed for policy enforcement. +For admitted HTTP requests, the proxy can run an ordered supervisor middleware +chain before credential injection. Host selectors choose the chain independently +of the network rule that admitted the request. Built-ins run in-process; +operator-registered services are called directly from the supervisor +over the common middleware gRPC contract. The gateway validates external +service capabilities and policy-owned config before delivery. Supervisors keep +the last-known-good service registry when a live config reload fails. + `https://inference.local` is special. It bypasses OPA network policy and is handled by the inference interception path: @@ -169,6 +177,8 @@ quickly. - If gateway config polling fails, the sandbox keeps its last-known-good policy. 
- If a live policy update is invalid, the supervisor rejects it and keeps the current policy. +- If an operator-run middleware call fails, the selected config's `on_error` + behavior decides whether to deny the request or continue without that stage. - Existing raw byte streams are connection scoped. Dynamic policy changes apply to new connections or the next parsed HTTP request where the proxy can safely re-evaluate. diff --git a/crates/openshell-core/src/grpc_client.rs b/crates/openshell-core/src/grpc_client.rs index 96158a1d1..836c7880c 100644 --- a/crates/openshell-core/src/grpc_client.rs +++ b/crates/openshell-core/src/grpc_client.rs @@ -24,11 +24,11 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use crate::proto::{ DenialSummary, GetDraftPolicyRequest, GetInferenceBundleRequest, GetInferenceBundleResponse, - GetSandboxConfigRequest, GetSandboxProviderEnvironmentRequest, IssueSandboxTokenRequest, - NetworkActivitySummary, PolicyChunk, PolicySource, PolicyStatus, RefreshSandboxTokenRequest, - ReportPolicyStatusRequest, SandboxPolicy as ProtoSandboxPolicy, SubmitPolicyAnalysisRequest, - SubmitPolicyAnalysisResponse, UpdateConfigRequest, inference_client::InferenceClient, - open_shell_client::OpenShellClient, + GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + IssueSandboxTokenRequest, NetworkActivitySummary, PolicyChunk, PolicySource, PolicyStatus, + RefreshSandboxTokenRequest, ReportPolicyStatusRequest, SandboxPolicy as ProtoSandboxPolicy, + SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, UpdateConfigRequest, + inference_client::InferenceClient, open_shell_client::OpenShellClient, }; use crate::sandbox_env; use miette::{IntoDiagnostic, Result, WrapErr}; @@ -573,19 +573,36 @@ pub async fn fetch_policy(endpoint: &str, sandbox_id: &str) -> Result Result { + debug!(endpoint = %endpoint, sandbox_id = %sandbox_id, "Connecting to OpenShell server"); + let mut client = connect(endpoint).await?; + 
fetch_sandbox_config_with_client(&mut client, sandbox_id).await +} + +async fn fetch_sandbox_config_with_client( client: &mut OpenShellClient, sandbox_id: &str, -) -> Result> { - let response = client +) -> Result { + client .get_sandbox_config(GetSandboxConfigRequest { sandbox_id: sandbox_id.to_string(), }) .await - .into_diagnostic()?; + .map(tonic::Response::into_inner) + .into_diagnostic() +} - let inner = response.into_inner(); +/// Fetch sandbox policy using an existing client connection. +async fn fetch_policy_with_client( + client: &mut OpenShellClient, + sandbox_id: &str, +) -> Result> { + let inner = fetch_sandbox_config_with_client(client, sandbox_id).await?; // version 0 with no policy means the sandbox was created without one. if inner.version == 0 && inner.policy.is_none() { @@ -711,6 +728,7 @@ pub struct SettingsPollResult { /// When `policy_source` is `Global`, the version of the global policy revision. pub global_policy_version: u32, pub provider_env_revision: u64, + pub supervisor_middleware_services: Vec, } pub struct ProviderEnvironmentResult { @@ -755,6 +773,7 @@ impl CachedOpenShellClient { settings: inner.settings, global_policy_version: inner.global_policy_version, provider_env_revision: inner.provider_env_revision, + supervisor_middleware_services: inner.supervisor_middleware_services, }) } diff --git a/crates/openshell-core/src/proto/mod.rs b/crates/openshell-core/src/proto/mod.rs index 08b062d2e..1ac6fc94c 100644 --- a/crates/openshell-core/src/proto/mod.rs +++ b/crates/openshell-core/src/proto/mod.rs @@ -79,8 +79,22 @@ pub mod inference { } } +#[allow( + clippy::all, + clippy::pedantic, + clippy::nursery, + unused_qualifications, + rust_2018_idioms +)] +pub mod middleware { + pub mod v1 { + include!(concat!(env!("OUT_DIR"), "/openshell.middleware.v1.rs")); + } +} + pub use datamodel::v1::*; pub use inference::v1::*; +pub use middleware::v1::*; pub use openshell::*; pub use sandbox::v1::*; pub use test::ObjectForTest; diff --git 
a/crates/openshell-policy/Cargo.toml b/crates/openshell-policy/Cargo.toml index 16719de13..073728db1 100644 --- a/crates/openshell-policy/Cargo.toml +++ b/crates/openshell-policy/Cargo.toml @@ -11,7 +11,10 @@ license.workspace = true repository.workspace = true [dependencies] +glob = { workspace = true } openshell-core = { path = "../openshell-core", default-features = false } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } +prost-types = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } serde_yml = { workspace = true } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 9d5dc5b25..c301af01a 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -12,15 +12,15 @@ mod compose; mod merge; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt; use std::path::Path; use miette::{IntoDiagnostic, Result, WrapErr}; use openshell_core::proto::{ FilesystemPolicy, GraphqlOperation, L7Allow, L7DenyRule, L7QueryMatcher, L7Rule, - LandlockPolicy, McpOptions, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, ProcessPolicy, - SandboxPolicy, + LandlockPolicy, McpOptions, MiddlewareEndpointSelector, NetworkBinary, NetworkEndpoint, + NetworkMiddlewareConfig, NetworkPolicyRule, ProcessPolicy, SandboxPolicy, }; use serde::{Deserialize, Serialize}; @@ -49,6 +49,8 @@ struct PolicyFile { process: Option, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] network_policies: BTreeMap, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + network_middlewares: Vec, } #[derive(Debug, Serialize, Deserialize)] @@ -89,6 +91,28 @@ struct NetworkPolicyRuleDef { binaries: Vec, } +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +struct NetworkMiddlewareConfigDef { + name: String, + middleware: String, + #[serde(default, skip_serializing_if = 
"BTreeMap::is_empty")] + config: BTreeMap, + #[serde(default, skip_serializing_if = "String::is_empty")] + on_error: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + endpoints: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +struct MiddlewareEndpointSelectorDef { + #[serde(default, skip_serializing_if = "Vec::is_empty")] + include: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + exclude: Vec, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] struct NetworkEndpointDef { @@ -672,6 +696,21 @@ fn yaml_mcp_method( } fn to_proto(raw: PolicyFile) -> SandboxPolicy { + let network_middlewares = raw + .network_middlewares + .into_iter() + .map(|mw| NetworkMiddlewareConfig { + name: mw.name, + middleware: mw.middleware, + config: Some(json_map_to_struct(mw.config)), + on_error: mw.on_error, + endpoints: mw.endpoints.map(|selector| MiddlewareEndpointSelector { + include: selector.include, + exclude: selector.exclude, + }), + }) + .collect(); + let network_policies = raw .network_policies .into_iter() @@ -776,6 +815,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { run_as_group: p.run_as_group, }), network_policies, + network_middlewares, } } @@ -908,12 +948,97 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { }) .collect(); + let network_middlewares = policy + .network_middlewares + .iter() + .map(|mw| NetworkMiddlewareConfigDef { + name: mw.name.clone(), + middleware: mw.middleware.clone(), + config: mw + .config + .as_ref() + .map(struct_to_json_map) + .unwrap_or_default(), + on_error: mw.on_error.clone(), + endpoints: mw + .endpoints + .as_ref() + .map(|selector| MiddlewareEndpointSelectorDef { + include: selector.include.clone(), + exclude: selector.exclude.clone(), + }), + }) + .collect(); + PolicyFile { version: policy.version, filesystem_policy, landlock, process, network_policies, + network_middlewares, + } +} + +fn json_map_to_struct(map: BTreeMap) -> 
prost_types::Struct { + prost_types::Struct { + fields: map + .into_iter() + .map(|(key, value)| (key, json_to_protobuf_value(value))) + .collect(), + } +} + +fn json_to_protobuf_value(value: serde_json::Value) -> prost_types::Value { + use prost_types::{ListValue, Struct, Value, value::Kind}; + Value { + kind: Some(match value { + serde_json::Value::Null => Kind::NullValue(0), + serde_json::Value::Bool(value) => Kind::BoolValue(value), + serde_json::Value::Number(value) => { + Kind::NumberValue(value.as_f64().unwrap_or_default()) + } + serde_json::Value::String(value) => Kind::StringValue(value), + serde_json::Value::Array(values) => Kind::ListValue(ListValue { + values: values.into_iter().map(json_to_protobuf_value).collect(), + }), + serde_json::Value::Object(values) => Kind::StructValue(Struct { + fields: values + .into_iter() + .map(|(key, value)| (key, json_to_protobuf_value(value))) + .collect(), + }), + }), + } +} + +fn struct_to_json_map(config: &prost_types::Struct) -> BTreeMap { + config + .fields + .iter() + .map(|(key, value)| (key.clone(), protobuf_value_to_json(value))) + .collect() +} + +fn protobuf_value_to_json(value: &prost_types::Value) -> serde_json::Value { + match value.kind.as_ref() { + Some(prost_types::value::Kind::NullValue(_)) | None => serde_json::Value::Null, + Some(prost_types::value::Kind::BoolValue(value)) => serde_json::Value::Bool(*value), + Some(prost_types::value::Kind::NumberValue(value)) => serde_json::Number::from_f64(*value) + .map_or(serde_json::Value::Null, serde_json::Value::Number), + Some(prost_types::value::Kind::StringValue(value)) => { + serde_json::Value::String(value.clone()) + } + Some(prost_types::value::Kind::ListValue(value)) => { + serde_json::Value::Array(value.values.iter().map(protobuf_value_to_json).collect()) + } + Some(prost_types::value::Kind::StructValue(value)) => serde_json::Value::Object( + value + .fields + .iter() + .map(|(key, value)| (key.clone(), protobuf_value_to_json(value))) + .collect(), + 
), } } @@ -1029,6 +1154,7 @@ pub fn restrictive_default_policy() -> SandboxPolicy { run_as_group: "sandbox".into(), }), network_policies: HashMap::new(), + network_middlewares: vec![], } } @@ -1084,6 +1210,18 @@ pub enum PolicyViolation { }, /// `credential_signing` and `request_body_credential_rewrite` are both set. CredentialSigningWithBodyRewrite { policy_name: String, host: String }, + /// A built-in middleware configuration is invalid. + InvalidBuiltinMiddlewareConfig { name: String, reason: String }, + /// A middleware configuration is structurally invalid. + InvalidMiddlewareConfig { name: String, reason: String }, + /// Middleware configuration names must be unique. + DuplicateMiddlewareConfigName { name: String }, + /// A middleware selector conflicts with an endpoint that skips TLS inspection. + MiddlewareTlsSkipConflict { + middleware_name: String, + policy_name: String, + host: String, + }, } impl fmt::Display for PolicyViolation { @@ -1145,10 +1283,67 @@ impl fmt::Display for PolicyViolation { and request_body_credential_rewrite set; these options are mutually exclusive" ) } + Self::InvalidBuiltinMiddlewareConfig { name, reason } + | Self::InvalidMiddlewareConfig { name, reason } => { + write!(f, "middleware config '{name}' is invalid: {reason}") + } + Self::DuplicateMiddlewareConfigName { name } => { + write!(f, "duplicate middleware config '{name}'") + } + Self::MiddlewareTlsSkipConflict { + middleware_name, + policy_name, + host, + } => { + write!( + f, + "middleware config '{middleware_name}' selects network policy \ + '{policy_name}' tls: skip endpoint '{host}'" + ) + } } } } +/// Match a middleware host selector pattern using the runtime's glob semantics. +/// +/// Invalid or empty patterns return an error instead of silently becoming a +/// non-match. 
+pub fn middleware_host_matches(pattern: &str, host: &str) -> std::result::Result { + if pattern.is_empty() { + return Err("host pattern must not be empty".to_string()); + } + if pattern.chars().any(char::is_whitespace) { + return Err("host pattern must not contain whitespace".to_string()); + } + + let pattern = glob::Pattern::new(&pattern.to_ascii_lowercase()) + .map_err(|error| format!("invalid host pattern: {error}"))?; + Ok(pattern.matches(&host.to_ascii_lowercase())) +} + +fn middleware_selector_matches_host( + middleware: &NetworkMiddlewareConfig, + host: &str, +) -> std::result::Result { + let Some(selector) = &middleware.endpoints else { + return Ok(false); + }; + let matches_include = selector + .include + .iter() + .try_fold(false, |matched, pattern| { + middleware_host_matches(pattern, host).map(|matches| matched || matches) + })?; + let matches_exclude = selector + .exclude + .iter() + .try_fold(false, |matched, pattern| { + middleware_host_matches(pattern, host).map(|matches| matched || matches) + })?; + Ok(matches_include && !matches_exclude) +} + /// Validate that a sandbox policy does not contain unsafe content. /// /// Returns `Ok(())` if the policy is safe, or `Err(violations)` listing all @@ -1163,6 +1358,9 @@ impl fmt::Display for PolicyViolation { /// - Individual path lengths must not exceed [`MAX_PATH_LENGTH`] /// - Total path count must not exceed [`MAX_FILESYSTEM_PATHS`] /// - Network endpoint hosts must not use TLD wildcards (e.g. 
`*.com`) +/// - Middleware names, implementations, failure modes, selectors, and built-in +/// configurations must be valid +/// - Middleware selectors must not match endpoints that skip TLS inspection pub fn validate_sandbox_policy( policy: &SandboxPolicy, ) -> std::result::Result<(), Vec> { @@ -1276,6 +1474,98 @@ pub fn validate_sandbox_policy( } } + let mut middleware_names = HashSet::new(); + for middleware in &policy.network_middlewares { + if middleware.name.is_empty() { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "name must not be empty".to_string(), + }); + } else if !middleware_names.insert(middleware.name.clone()) { + violations.push(PolicyViolation::DuplicateMiddlewareConfigName { + name: middleware.name.clone(), + }); + } + + if middleware.middleware.is_empty() { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "implementation must not be empty".to_string(), + }); + } else if middleware.middleware.starts_with("openshell/") + && middleware.middleware != openshell_supervisor_middleware::BUILTIN_SECRETS + { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("unsupported built-in '{}'", middleware.middleware), + }); + } + + if !matches!( + middleware.on_error.as_str(), + "" | "fail_closed" | "fail_open" + ) { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("invalid on_error '{}'", middleware.on_error), + }); + } + + let Some(selector) = &middleware.endpoints else { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "endpoint selector is required".to_string(), + }); + continue; + }; + if selector.include.is_empty() { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "endpoint selector must include at least one host 
pattern".to_string(), + }); + } + for pattern in selector.include.iter().chain(&selector.exclude) { + if let Err(reason) = middleware_host_matches(pattern, "validation.invalid") { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("endpoint selector pattern '{pattern}' is invalid: {reason}"), + }); + } + } + + if middleware.middleware == openshell_supervisor_middleware::BUILTIN_SECRETS { + let config = middleware.config.clone().unwrap_or_default(); + if let Err(error) = openshell_supervisor_middleware::validate_builtin_config( + &middleware.middleware, + &config, + ) { + violations.push(PolicyViolation::InvalidBuiltinMiddlewareConfig { + name: middleware.name.clone(), + reason: error.to_string(), + }); + } + } + + for (key, rule) in &policy.network_policies { + let policy_name = if rule.name.is_empty() { + key + } else { + &rule.name + }; + for endpoint in &rule.endpoints { + if endpoint.tls == "skip" + && middleware_selector_matches_host(middleware, &endpoint.host).unwrap_or(false) + { + violations.push(PolicyViolation::MiddlewareTlsSkipConflict { + middleware_name: middleware.name.clone(), + policy_name: policy_name.clone(), + host: endpoint.host.clone(), + }); + } + } + } + } + if violations.is_empty() { Ok(()) } else { @@ -1399,6 +1689,70 @@ network_policies: assert_eq!(proto2.network_policies["my_api"].name, "my-custom-api-name"); } + #[test] + fn round_trip_preserves_network_middlewares() { + let yaml = r#" +version: 1 +network_middlewares: + - name: global-redactor + middleware: openshell/secrets + on_error: fail_open + endpoints: + include: ["api.example.com", "*.service.test"] + exclude: ["internal.example.com"] + config: + secrets: ["api_key", "authorization"] + service: + mode: redact + max_matches: 2 + - name: secondary-redactor + middleware: openshell/secrets + endpoints: + include: ["api.example.com"] +network_policies: + api: + name: api + endpoints: + - host: api.example.com + port: 443 + 
protocol: rest + binaries: + - path: /usr/bin/curl +"#; + let proto = parse_sandbox_policy(yaml).expect("parse failed"); + assert_eq!(proto.network_middlewares.len(), 2); + assert_eq!(proto.network_middlewares[0].name, "global-redactor"); + assert_eq!(proto.network_middlewares[0].middleware, "openshell/secrets"); + assert_eq!(proto.network_middlewares[0].on_error, "fail_open"); + assert_eq!( + proto.network_middlewares[0] + .endpoints + .as_ref() + .expect("selector") + .include, + vec!["api.example.com", "*.service.test"] + ); + assert_eq!( + proto.network_middlewares[0] + .endpoints + .as_ref() + .expect("selector") + .exclude, + vec!["internal.example.com"] + ); + assert!( + proto.network_middlewares[0] + .config + .as_ref() + .expect("config") + .fields + .contains_key("service") + ); + let yaml_out = serialize_sandbox_policy(&proto).expect("serialize failed"); + let reparsed = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + assert_eq!(reparsed.network_middlewares, proto.network_middlewares); + } + #[test] fn restrictive_default_has_no_network_policies() { let policy = restrictive_default_policy(); @@ -1529,6 +1883,31 @@ network_policies: assert!(parse_sandbox_policy(yaml).is_err()); } + #[test] + fn parse_rejects_middleware_attachments_on_network_policies_and_endpoints() { + let policy_attachment = r" +version: 1 +network_policies: + api: + middleware: [redact] + endpoints: + - host: api.example.com + port: 443 +"; + assert!(parse_sandbox_policy(policy_attachment).is_err()); + + let endpoint_attachment = r" +version: 1 +network_policies: + api: + endpoints: + - host: api.example.com + port: 443 + middleware: [redact] +"; + assert!(parse_sandbox_policy(endpoint_attachment).is_err()); + } + #[test] fn l7_config_stanza_runtime_fields_use_canonical_schema() { let fields = l7_config_alias_runtime_fields( @@ -1602,6 +1981,19 @@ network_policies: // ---- Policy validation tests ---- + fn middleware_config(name: &str, implementation: &str) -> 
NetworkMiddlewareConfig { + NetworkMiddlewareConfig { + name: name.into(), + middleware: implementation.into(), + config: None, + on_error: String::new(), + endpoints: Some(MiddlewareEndpointSelector { + include: vec!["api.example.com".into()], + exclude: Vec::new(), + }), + } + } + #[test] fn validate_rejects_root_run_as_user() { let mut policy = restrictive_default_policy(); @@ -1630,6 +2022,157 @@ network_policies: assert_eq!(violations.len(), 2); } + #[test] + fn validate_rejects_invalid_builtin_middleware_config() { + let mut policy = restrictive_default_policy(); + let mut middleware = middleware_config("redact-secrets", "openshell/secrets"); + middleware.config = Some(prost_types::Struct { + fields: std::iter::once(( + "secrets".into(), + prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue("allow".into())), + }, + )) + .collect(), + }); + policy.network_middlewares.push(middleware); + + let violations = validate_sandbox_policy(&policy).expect_err("invalid config"); + assert!(violations.iter().any(|violation| matches!( + violation, + PolicyViolation::InvalidBuiltinMiddlewareConfig { name, .. 
} + if name == "redact-secrets" + ))); + } + + #[test] + fn validate_rejects_invalid_middleware_control_fields() { + let cases = [ + ( + middleware_config("", "openshell/secrets"), + "name must not be empty", + ), + ( + middleware_config("redactor", ""), + "implementation must not be empty", + ), + ( + middleware_config("redactor", "openshell/unknown"), + "unsupported built-in", + ), + ( + { + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.on_error = "maybe".into(); + middleware + }, + "invalid on_error", + ), + ( + { + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.endpoints = None; + middleware + }, + "endpoint selector is required", + ), + ( + { + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.endpoints.as_mut().unwrap().include.clear(); + middleware + }, + "must include at least one host pattern", + ), + ]; + + for (middleware, expected) in cases { + let mut policy = restrictive_default_policy(); + policy.network_middlewares.push(middleware); + let errors = validate_sandbox_policy(&policy) + .expect_err("invalid middleware must be rejected") + .into_iter() + .map(|violation| violation.to_string()) + .collect::>() + .join("; "); + assert!( + errors.contains(expected), + "expected {expected:?} in {errors:?}" + ); + } + } + + #[test] + fn validate_rejects_duplicate_middleware_config_names() { + let mut policy = restrictive_default_policy(); + policy + .network_middlewares + .push(middleware_config("redactor", "openshell/secrets")); + policy + .network_middlewares + .push(middleware_config("redactor", "openshell/secrets")); + + let violations = validate_sandbox_policy(&policy).expect_err("duplicate name"); + assert!(violations.iter().any(|violation| matches!( + violation, + PolicyViolation::DuplicateMiddlewareConfigName { name } if name == "redactor" + ))); + } + + #[test] + fn validate_rejects_malformed_middleware_selector_patterns() { + let mut 
policy = restrictive_default_policy(); + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.endpoints.as_mut().unwrap().include = vec!["api[.example.com".into()]; + policy.network_middlewares.push(middleware); + + let errors = validate_sandbox_policy(&policy) + .expect_err("malformed selector") + .into_iter() + .map(|violation| violation.to_string()) + .collect::>() + .join("; "); + assert!(errors.contains("invalid host pattern"), "{errors}"); + } + + #[test] + fn middleware_host_selector_matching_is_case_insensitive() { + assert!(middleware_host_matches("*.Example.COM", "API.example.com").unwrap()); + assert!(!middleware_host_matches("*.example.com", "example.com").unwrap()); + assert!(middleware_host_matches("*", "deep.api.example.com").unwrap()); + } + + #[test] + fn validate_rejects_middleware_selector_matching_tls_skip_endpoint() { + let mut policy = restrictive_default_policy(); + policy + .network_middlewares + .push(middleware_config("redactor", "openshell/secrets")); + policy.network_policies.insert( + "api".into(), + NetworkPolicyRule { + name: "api".into(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".into(), + port: 443, + tls: "skip".into(), + ..Default::default() + }], + binaries: Vec::new(), + }, + ); + + let violations = validate_sandbox_policy(&policy).expect_err("tls skip conflict"); + assert!(violations.iter().any(|violation| matches!( + violation, + PolicyViolation::MiddlewareTlsSkipConflict { + middleware_name, + policy_name, + host, + } if middleware_name == "redactor" && policy_name == "api" && host == "api.example.com" + ))); + } + #[test] fn validate_rejects_non_sandbox_user() { let mut policy = restrictive_default_policy(); @@ -1714,6 +2257,7 @@ network_policies: filesystem: None, landlock: None, network_policies: HashMap::new(), + network_middlewares: Vec::new(), }; assert!(validate_sandbox_policy(&policy).is_ok()); } diff --git a/crates/openshell-sandbox/Cargo.toml 
b/crates/openshell-sandbox/Cargo.toml index 086dbe02c..d3c3e7108 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -19,6 +19,7 @@ openshell-core = { path = "../openshell-core", default-features = false } openshell-ocsf = { path = "../openshell-ocsf" } openshell-policy = { path = "../openshell-policy" } openshell-supervisor-network = { path = "../openshell-supervisor-network" } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } openshell-supervisor-process = { path = "../openshell-supervisor-process" } # Async runtime diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index d5967d1f3..c1531023a 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -1376,12 +1376,12 @@ async fn load_policy( endpoint = %endpoint, "Fetching sandbox policy via gRPC" ); - let proto_policy = grpc_retry("Policy fetch", || { - openshell_core::grpc_client::fetch_policy(endpoint, id) + let mut sandbox_config = grpc_retry("Policy fetch", || { + openshell_core::grpc_client::fetch_sandbox_config(endpoint, id) }) .await?; - let mut proto_policy = if let Some(p) = proto_policy { + let mut proto_policy = if let Some(p) = sandbox_config.policy.take() { p } else { // No policy configured on the server. Discover from disk or @@ -1409,7 +1409,7 @@ async fn load_policy( // Sync and re-fetch over a single connection to avoid extra // TLS handshakes. - grpc_retry("Policy discovery sync", || { + let synced = grpc_retry("Policy discovery sync", || { openshell_core::grpc_client::discover_and_sync_policy( endpoint, id, @@ -1417,7 +1417,12 @@ async fn load_policy( &discovered, ) }) - .await? 
+ .await?; + sandbox_config = grpc_retry("Policy refetch after discovery", || { + openshell_core::grpc_client::fetch_sandbox_config(endpoint, id) + }) + .await?; + sandbox_config.policy.take().unwrap_or(synced) }; // Ensure baseline filesystem paths are present for proxy-mode @@ -1443,7 +1448,14 @@ async fn load_policy( // container hasn't started yet. After the entrypoint spawns, the // engine is rebuilt with the real PID for symlink resolution. info!("Creating OPA engine from proto policy data"); - let opa_engine = Some(Arc::new(OpaEngine::from_proto(&proto_policy)?)); + let engine = OpaEngine::from_proto(&proto_policy)?; + let middleware_registry = + openshell_supervisor_middleware::MiddlewareRegistry::connect_services( + sandbox_config.supervisor_middleware_services, + ) + .await?; + engine.replace_middleware_registry(middleware_registry)?; + let opa_engine = Some(Arc::new(engine)); let policy = SandboxPolicy::try_from(proto_policy.clone())?; return Ok((policy, opa_engine, Some(proto_policy))); @@ -1593,6 +1605,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { let mut current_config_revision: u64 = 0; let mut current_provider_env_revision: u64 = ctx.provider_credentials.snapshot().revision; let mut current_policy_hash = String::new(); + let mut current_middleware_services = Vec::new(); let mut current_settings: std::collections::HashMap< String, openshell_core::proto::EffectiveSetting, @@ -1604,6 +1617,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { apply_ocsf_json_setting(&ctx.ocsf_enabled, &result.settings); current_config_revision = result.config_revision; current_policy_hash = result.policy_hash.clone(); + current_middleware_services = result.supervisor_middleware_services; current_settings = result.settings; debug!( config_revision = current_config_revision, @@ -1633,6 +1647,8 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } let policy_changed = result.policy_hash != 
current_policy_hash; + let middleware_changed = + result.supervisor_middleware_services != current_middleware_services; // Log which settings changed. log_setting_changes(¤t_settings, &result.settings); @@ -1691,6 +1707,47 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } } + if middleware_changed { + match openshell_supervisor_middleware::MiddlewareRegistry::connect_services( + result.supervisor_middleware_services.clone(), + ) + .await + .and_then(|registry| ctx.opa_engine.replace_middleware_registry(registry)) + { + Ok(()) => { + current_middleware_services = result.supervisor_middleware_services.clone(); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped( + "supervisor_middleware_service_count", + serde_json::json!(current_middleware_services.len()) + ) + .message(format!( + "Supervisor middleware registry reloaded [service_count:{}]", + current_middleware_services.len() + )) + .build() + ); + } + Err(error) => { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "failed") + .message(format!( + "Supervisor middleware registry reload failed, keeping last-known-good registry [error:{error}]" + )) + .build() + ); + continue; + } + } + } + // Only reload OPA when the policy payload actually changed. 
if policy_changed { let Some(policy) = result.policy.as_ref() else { diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index b5c9b34d7..fafc72ba7 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -26,6 +26,7 @@ openshell-prover = { path = "../openshell-prover" } openshell-providers = { path = "../openshell-providers" } openshell-router = { path = "../openshell-router" } openshell-server-macros = { path = "../openshell-server-macros" } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } # Kubernetes client (used by the `generate-certs` subcommand) kube = { workspace = true } diff --git a/crates/openshell-server/src/config_file.rs b/crates/openshell-server/src/config_file.rs index 3875756dc..ca0987442 100644 --- a/crates/openshell-server/src/config_file.rs +++ b/crates/openshell-server/src/config_file.rs @@ -25,6 +25,7 @@ use std::net::SocketAddr; use std::path::{Path, PathBuf}; use openshell_core::config::ComputeDriverKind; +use openshell_core::proto::SupervisorMiddlewareService; use openshell_core::{GatewayAuthConfig, GatewayJwtConfig, MtlsAuthConfig, OidcConfig, TlsConfig}; use serde::{Deserialize, Serialize}; @@ -151,6 +152,12 @@ pub struct GatewayFileSection { #[serde(default)] pub gateway_jwt: Option, + // ── Supervisor middleware ───────────────────────────────────────────── + /// Statically registered supervisor middleware services. Registration is + /// operator-owned and changes require a gateway restart. + #[serde(default)] + pub middleware: Vec, + // ── Disallowed-in-file fields ──────────────────────────────────────── // // Captured so we can produce a friendly "set this via env/CLI instead" @@ -160,6 +167,32 @@ pub struct GatewayFileSection { pub database_url: Option, } +/// One `[[openshell.gateway.middleware]]` supervisor middleware registration. 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct MiddlewareServiceFileConfig { + /// Operator-facing name used for diagnostics. + pub name: String, + /// Plaintext gRPC endpoint reachable by the gateway and supervisors. + pub endpoint: String, + /// Required explicit opt-in to the local-development-only insecure mode. + #[serde(default)] + pub allow_insecure: bool, + /// Operator-owned body limit for every binding exposed by this service. + pub max_body_bytes: u64, +} + +impl From<&MiddlewareServiceFileConfig> for SupervisorMiddlewareService { + fn from(config: &MiddlewareServiceFileConfig) -> Self { + Self { + name: config.name.clone(), + endpoint: config.endpoint.clone(), + allow_insecure: config.allow_insecure, + max_body_bytes: config.max_body_bytes, + } + } +} + #[derive(Debug, thiserror::Error)] pub enum ConfigFileError { #[error("failed to read gateway config file '{}': {source}", path.display())] @@ -400,6 +433,28 @@ allow_unauthenticated_users = true assert!(auth.allow_unauthenticated_users); } + #[test] + fn parses_supervisor_middleware_registration() { + let toml = r#" +[[openshell.gateway.middleware]] +name = "local-guard" +endpoint = "http://127.0.0.1:50051" +allow_insecure = true +max_body_bytes = 262144 +"#; + let tmp = write_tmp(toml); + let file = load(tmp.path()).expect("valid middleware registration parses"); + assert_eq!( + file.openshell.gateway.middleware, + vec![MiddlewareServiceFileConfig { + name: "local-guard".into(), + endpoint: "http://127.0.0.1:50051".into(), + allow_insecure: true, + max_body_bytes: 262_144, + }] + ); + } + #[test] fn rejects_database_url_in_file() { let toml = r#" diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index cc8ff0d2e..9587e7295 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -1218,8 +1218,26 @@ pub(super) async fn handle_get_sandbox_config( 
} } + if let Some(policy) = policy.as_ref() { + state + .middleware_registry + .ensure_policy_bindings_registered(policy) + .map_err(|error| { + Status::failed_precondition(format!( + "effective policy middleware registration is invalid: {error}" + )) + })?; + } + let settings = merge_effective_settings(&global_settings, &sandbox_settings)?; - let config_revision = compute_config_revision(policy.as_ref(), &settings, policy_source); + let supervisor_middleware_services = + state.middleware_registry.required_services(policy.as_ref()); + let config_revision = compute_config_revision( + policy.as_ref(), + &settings, + policy_source, + &supervisor_middleware_services, + ); let provider_env_revision = compute_provider_env_revision(state.store.as_ref(), &sandbox_provider_names).await?; @@ -1232,6 +1250,7 @@ pub(super) async fn handle_get_sandbox_config( policy_source: policy_source.into(), global_policy_version, provider_env_revision, + supervisor_middleware_services, })) } @@ -1510,6 +1529,8 @@ async fn handle_update_config_inner( openshell_policy::ensure_sandbox_process_identity(&mut new_policy); validate_no_reserved_provider_policy_keys(&new_policy)?; validate_policy_safety(&new_policy)?; + crate::middleware::validate_policy(state.middleware_registry.as_ref(), &new_policy) + .await?; let payload = new_policy.encode_to_vec(); let hash = deterministic_policy_hash(&new_policy); @@ -1827,9 +1848,11 @@ async fn handle_update_config_inner( validate_no_reserved_provider_policy_keys(&new_policy)?; } + validate_policy_safety(&new_policy)?; + crate::middleware::validate_policy(state.middleware_registry.as_ref(), &new_policy).await?; + if let Some(baseline_policy) = spec.policy.as_ref() { validate_static_fields_unchanged(baseline_policy, &new_policy)?; - validate_policy_safety(&new_policy)?; } else { // Backfill spec.policy using CAS (first-time policy discovery) let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; @@ -3120,6 +3143,7 @@ fn compute_config_revision( 
policy: Option<&ProtoSandboxPolicy>, settings: &HashMap, policy_source: PolicySource, + supervisor_middleware_services: &[openshell_core::proto::SupervisorMiddlewareService], ) -> u64 { let mut hasher = Sha256::new(); hasher.update((policy_source as i32).to_le_bytes()); @@ -3152,6 +3176,11 @@ fn compute_config_revision( } } } + let mut middleware = supervisor_middleware_services.iter().collect::>(); + middleware.sort_by(|left, right| left.name.cmp(&right.name)); + for service in middleware { + hasher.update(service.encode_to_vec()); + } let digest = hasher.finalize(); let mut bytes = [0_u8; 8]; @@ -8773,7 +8802,7 @@ mod tests { allowed_ips: vec!["127.0.0.1".to_string()], ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8794,7 +8823,7 @@ mod tests { allowed_ips: vec!["169.254.169.254".to_string()], ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8812,7 +8841,7 @@ mod tests { port: 80, ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8830,7 +8859,7 @@ mod tests { port: 8080, ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8848,7 +8877,7 @@ mod tests { port: 80, ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8896,7 +8925,7 @@ mod tests { allowed_ips: vec!["10.0.5.0/24".to_string()], ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_ok()); @@ -8913,7 +8942,7 @@ mod tests { port: 443, ..Default::default() }], - binaries: vec![], + ..Default::default() }; 
let result = validate_rule_not_always_blocked(&rule); assert!(result.is_ok()); @@ -8989,7 +9018,7 @@ mod tests { }, ); - let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); settings.insert( "mode".to_string(), EffectiveSetting { @@ -8999,7 +9028,7 @@ mod tests { scope: SettingScope::Sandbox.into(), }, ); - let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); + let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); assert_ne!(rev_a, rev_b); } @@ -9264,8 +9293,8 @@ mod tests { }, ); - let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); - let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); + let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); assert_eq!(rev_a, rev_b); } @@ -9281,8 +9310,8 @@ mod tests { }; let settings = HashMap::new(); - let rev_a = compute_config_revision(Some(&policy_a), &settings, PolicySource::Sandbox); - let rev_b = compute_config_revision(Some(&policy_b), &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(Some(&policy_a), &settings, PolicySource::Sandbox, &[]); + let rev_b = compute_config_revision(Some(&policy_b), &settings, PolicySource::Sandbox, &[]); assert_ne!(rev_a, rev_b); } @@ -9291,11 +9320,28 @@ mod tests { let policy = ProtoSandboxPolicy::default(); let settings = HashMap::new(); - let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); - let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Global); + let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); + let rev_b = compute_config_revision(Some(&policy), &settings, 
PolicySource::Global, &[]); assert_ne!(rev_a, rev_b); } + #[test] + fn config_revision_changes_when_supervisor_middleware_services_change() { + let policy = ProtoSandboxPolicy::default(); + let settings = HashMap::new(); + let service = openshell_core::proto::SupervisorMiddlewareService { + name: "local-guard".into(), + endpoint: "http://127.0.0.1:50051".into(), + allow_insecure: true, + max_body_bytes: 1024, + }; + + let without = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); + let with = + compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[service]); + assert_ne!(without, with); + } + #[test] fn config_revision_without_policy_still_hashes_settings() { let mut settings = HashMap::new(); @@ -9309,7 +9355,7 @@ mod tests { }, ); - let rev_a = compute_config_revision(None, &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(None, &settings, PolicySource::Sandbox, &[]); settings.insert( "log_level".to_string(), @@ -9321,7 +9367,7 @@ mod tests { }, ); - let rev_b = compute_config_revision(None, &settings, PolicySource::Sandbox); + let rev_b = compute_config_revision(None, &settings, PolicySource::Sandbox, &[]); assert_ne!(rev_a, rev_b); } diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 04d5a4ed5..203cd7dbe 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -164,6 +164,7 @@ async fn handle_create_sandbox_inner( openshell_policy::ensure_sandbox_process_identity(policy); validate_no_reserved_provider_policy_keys(policy)?; validate_policy_safety(policy)?; + crate::middleware::validate_policy(state.middleware_registry.as_ref(), policy).await?; } let id = uuid::Uuid::new_v4().to_string(); diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 09e9f1cad..af6b84af6 100644 --- a/crates/openshell-server/src/grpc/validation.rs 
+++ b/crates/openshell-server/src/grpc/validation.rs @@ -1621,6 +1621,28 @@ mod tests { assert!(err.message().contains("TLD wildcard")); } + #[test] + fn validate_policy_safety_rejects_invalid_middleware_before_acceptance() { + use openshell_core::proto::{MiddlewareEndpointSelector, NetworkMiddlewareConfig}; + + let mut policy = openshell_policy::restrictive_default_policy(); + policy.network_middlewares.push(NetworkMiddlewareConfig { + name: "redactor".into(), + middleware: "openshell/secrets".into(), + on_error: "maybe".into(), + endpoints: Some(MiddlewareEndpointSelector { + include: vec!["api[.example.com".into()], + exclude: Vec::new(), + }), + ..Default::default() + }); + + let err = validate_policy_safety(&policy).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("invalid on_error")); + assert!(err.message().contains("invalid host pattern")); + } + #[test] fn validate_no_reserved_provider_policy_keys_rejects_reserved_key() { use openshell_core::proto::NetworkPolicyRule; diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index d0dbb1681..cc318af87 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -32,6 +32,7 @@ mod defaults; mod grpc; mod http; mod inference; +mod middleware; mod multiplex; mod persistence; pub(crate) mod policy_store; @@ -53,6 +54,7 @@ mod ws_tunnel; use metrics_exporter_prometheus::PrometheusBuilder; use openshell_core::{ComputeDriverKind, Config, Error, Result}; +use openshell_supervisor_middleware::MiddlewareRegistry; use serde::Deserialize; use std::collections::HashMap; use std::io::ErrorKind; @@ -123,6 +125,9 @@ pub struct ServerState { /// query session state to surface supervisor readiness. pub supervisor_sessions: Arc, + /// Validated built-in and operator-registered supervisor middleware. + pub middleware_registry: Arc, + /// OIDC JWKS cache for JWT validation. `None` when OIDC is not configured. 
pub oidc_cache: Option>, @@ -189,6 +194,7 @@ impl ServerState { ssh_connections_by_sandbox: Mutex::new(HashMap::new()), settings_mutex: tokio::sync::Mutex::new(()), supervisor_sessions, + middleware_registry: Arc::new(MiddlewareRegistry::default()), oidc_cache, sandbox_jwt_issuer: None, sandbox_jwt_authenticator: None, @@ -217,6 +223,23 @@ pub async fn run_server( return Err(Error::config("database_url is required")); } + let middleware_registrations = config_file + .as_ref() + .map(|file| { + file.openshell + .gateway + .middleware + .iter() + .map(Into::into) + .collect() + }) + .unwrap_or_default(); + let middleware_registry = Arc::new( + MiddlewareRegistry::connect_services(middleware_registrations) + .await + .map_err(|error| Error::config(format!("middleware registration failed: {error}")))?, + ); + let store = Arc::new(Store::connect(database_url).await?); let oidc_cache = if let Some(ref oidc) = config.oidc { @@ -262,6 +285,7 @@ pub async fn run_server( supervisor_sessions, oidc_cache, ); + state.middleware_registry = middleware_registry; // Load the gateway-minted sandbox JWT signing key when configured. // Optional so single-driver dev deployments without certgen continue diff --git a/crates/openshell-server/src/middleware.rs b/crates/openshell-server/src/middleware.rs new file mode 100644 index 000000000..4c94f021a --- /dev/null +++ b/crates/openshell-server/src/middleware.rs @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use openshell_core::proto::SandboxPolicy; +use openshell_supervisor_middleware::MiddlewareRegistry; +use tonic::Status; + +/// Validate implementation-owned middleware config before accepting a policy. 
+pub async fn validate_policy( + registry: &MiddlewareRegistry, + policy: &SandboxPolicy, +) -> Result<(), Status> { + registry + .validate_policy_configs(policy) + .await + .map_err(|error| { + Status::invalid_argument(format!("policy middleware validation failed: {error}")) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use openshell_core::proto::NetworkMiddlewareConfig; + + #[tokio::test] + async fn unregistered_external_binding_is_rejected_before_admission() { + let policy = SandboxPolicy { + network_middlewares: vec![NetworkMiddlewareConfig { + name: "guard".into(), + middleware: "example/content-guard".into(), + ..Default::default() + }], + ..Default::default() + }; + + let error = validate_policy(&MiddlewareRegistry::default(), &policy) + .await + .expect_err("unregistered binding must fail"); + assert_eq!(error.code(), tonic::Code::InvalidArgument); + assert!(error.message().contains("not registered")); + } +} diff --git a/crates/openshell-supervisor-middleware/Cargo.toml b/crates/openshell-supervisor-middleware/Cargo.toml new file mode 100644 index 000000000..e5e53618d --- /dev/null +++ b/crates/openshell-supervisor-middleware/Cargo.toml @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-supervisor-middleware" +description = "In-process supervisor middleware contract and built-ins for OpenShell" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +openshell-core = { path = "../openshell-core" } + +miette = { workspace = true } +prost-types = { workspace = true } +regex = { workspace = true } +tokio = { workspace = true } +tonic = { workspace = true, features = ["channel", "server"] } + +[dev-dependencies] +tokio-stream = { workspace = true, features = ["net"] } + +[lints] +workspace = true diff --git a/crates/openshell-supervisor-middleware/src/builtins/mod.rs b/crates/openshell-supervisor-middleware/src/builtins/mod.rs new file mode 100644 index 000000000..1db620220 --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/builtins/mod.rs @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +pub mod secrets; + +use miette::{Result, miette}; +use openshell_core::proto::{HttpRequestEvaluation, HttpRequestResult, MiddlewareBinding}; + +pub fn describe() -> Vec { + vec![secrets::describe()] +} + +pub fn validate_config(binding_id: &str, config: &prost_types::Struct) -> Result<()> { + match binding_id { + secrets::BINDING_ID => secrets::validate_config(config), + other => Err(miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + } +} + +pub fn evaluate_http_request(evaluation: &HttpRequestEvaluation) -> Result { + match evaluation.binding_id.as_str() { + secrets::BINDING_ID => secrets::evaluate_http_request(evaluation), + other => Err(miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + } +} diff --git a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs new file mode 100644 index 000000000..d88ac080d --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs @@ -0,0 +1,132 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; +use std::sync::LazyLock; + +use miette::{Result, miette}; +use openshell_core::proto::{ + Decision, Finding, HttpRequestEvaluation, HttpRequestResult, MiddlewareBinding, +}; +use regex::Regex; + +pub const BINDING_ID: &str = "openshell/secrets"; +const OPERATION: &str = "HttpRequest"; +const PHASE: &str = "pre_credentials"; +const MAX_BODY_BYTES: u64 = 256 * 1024; + +pub fn describe() -> MiddlewareBinding { + MiddlewareBinding { + id: BINDING_ID.into(), + operation: OPERATION.into(), + phase: PHASE.into(), + max_body_bytes: MAX_BODY_BYTES, + } +} + +/// A named secret-detection pattern. 
The `kind` is an audit-safe label that +/// flows into findings so operators can see *what* matched without seeing the +/// raw value. +struct SecretPattern { + kind: &'static str, + regex: Regex, +} + +impl SecretPattern { + fn new(kind: &'static str, pattern: &str) -> Self { + Self { + kind, + regex: Regex::new(pattern).expect("valid built-in secret redaction pattern"), + } + } +} + +/// Compiled once: recompiling per request would put regex construction on the +/// egress hot path. +static SECRET_PATTERNS: LazyLock<[SecretPattern; 2]> = LazyLock::new(|| { + [ + SecretPattern::new( + "keyword", + r#"(?i)(api[_-]?key|access[_-]?token|secret|password)(["']?\s*[:=]\s*["'])[^"',\s}]+(["']?)"#, + ), + SecretPattern::new("openai", r"(sk-[A-Za-z0-9_-]{16,})"), + ] +}); + +pub fn validate_config(config: &prost_types::Struct) -> Result<()> { + let mode = config + .fields + .get("secrets") + .and_then(|value| match value.kind.as_ref() { + Some(prost_types::value::Kind::StringValue(value)) => Some(value.as_str()), + _ => None, + }) + .unwrap_or("redact"); + if mode != "redact" { + return Err(miette!( + "{} only supports config.secrets: redact in phase 1", + BINDING_ID + )); + } + Ok(()) +} + +pub fn evaluate_http_request(evaluation: &HttpRequestEvaluation) -> Result { + let default_config = prost_types::Struct::default(); + validate_config(evaluation.config.as_ref().unwrap_or(&default_config))?; + let text = String::from_utf8(evaluation.body.clone()) + .map_err(|_| miette!("{} requires UTF-8 request bodies", BINDING_ID))?; + let (body, matches) = redact_common_secrets(&text); + let total: u32 = matches + .iter() + .fold(0u32, |acc, (_, count)| acc.saturating_add(*count)); + let mut result = HttpRequestResult { + decision: Decision::Allow as i32, + reason: String::new(), + body: body.into_bytes(), + has_body: !matches.is_empty(), + add_headers: HashMap::new(), + findings: Vec::new(), + metadata: HashMap::new(), + }; + if !matches.is_empty() { + // One finding per matched 
pattern kind, so audit shows what matched. + for (kind, count) in &matches { + result.findings.push(Finding { + r#type: format!("secret.{kind}"), + label: format!("{kind} secret pattern"), + count: *count, + confidence: "medium".into(), + severity: "medium".into(), + }); + } + result + .metadata + .insert("secrets_redacted".into(), total.to_string()); + } + Ok(result) +} + +/// Redact every configured secret pattern, returning the transformed text and +/// the per-kind match counts (only kinds that matched are included). +fn redact_common_secrets(input: &str) -> (String, Vec<(&'static str, u32)>) { + let mut output = input.to_string(); + let mut matches = Vec::new(); + for pattern in SECRET_PATTERNS.iter() { + let count = u32::try_from(pattern.regex.find_iter(&output).count()).unwrap_or(u32::MAX); + if count > 0 { + matches.push((pattern.kind, count)); + } + output = pattern + .regex + .replace_all(&output, |captures: ®ex::Captures<'_>| { + if captures.len() >= 4 { + format!("{}{}[REDACTED]{}", &captures[1], &captures[2], &captures[3]) + } else { + "[REDACTED]".to_string() + } + }) + .into_owned(); + } + (output, matches) +} diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs new file mode 100644 index 000000000..ebc5817e4 --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -0,0 +1,1533 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Supervisor middleware registration and chain execution. 
+ +mod builtins; +mod remote; +mod service; + +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::sync::{Arc, LazyLock}; + +use miette::{Result, miette}; +pub use service::InProcessMiddlewareService; + +use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; +use openshell_core::proto::{ + Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, MiddlewareBinding, + MiddlewareManifest, NetworkMiddlewareConfig, RequestContext, SandboxPolicy, + SupervisorMiddlewareService, ValidateConfigRequest, +}; +use tokio::sync::OnceCell; +use tonic::Request; + +pub const API_VERSION: &str = "openshell.middleware.v1"; +pub const BUILTIN_SECRETS: &str = builtins::secrets::BINDING_ID; +const HTTP_REQUEST_OPERATION: &str = "HttpRequest"; +const PRE_CREDENTIALS_PHASE: &str = "pre_credentials"; + +/// Validate the configuration for an in-process middleware implementation. +/// +/// Policy admission uses this same implementation-specific validation before a +/// configuration can reach the request path. 
+pub fn validate_builtin_config(implementation: &str, config: &prost_types::Struct) -> Result<()> { + builtins::validate_config(implementation, config) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OnError { + FailClosed, + FailOpen, +} + +impl OnError { + pub fn parse(value: &str) -> Result { + match value { + "" | "fail_closed" => Ok(Self::FailClosed), + "fail_open" => Ok(Self::FailOpen), + other => Err(miette!( + "invalid middleware on_error '{other}', expected fail_closed or fail_open" + )), + } + } +} + +#[derive(Debug, Clone)] +pub struct ChainEntry { + pub name: String, + pub implementation: String, + pub config: prost_types::Struct, + pub on_error: OnError, +} + +impl TryFrom<&NetworkMiddlewareConfig> for ChainEntry { + type Error = miette::Report; + + fn try_from(value: &NetworkMiddlewareConfig) -> Result { + if value.name.is_empty() { + return Err(miette!("middleware config name cannot be empty")); + } + if value.middleware.is_empty() { + return Err(miette!( + "middleware config '{}' must name an implementation", + value.name + )); + } + Ok(Self { + name: value.name.clone(), + implementation: value.middleware.clone(), + config: value.config.clone().unwrap_or_default(), + on_error: OnError::parse(&value.on_error)?, + }) + } +} + +/// A policy-selected middleware config joined with metadata reported by its +/// service's `Describe` call. A missing binding is retained so `on_error` can +/// decide whether the request fails open or closed. 
+#[derive(Clone)] +pub struct DescribedChainEntry { + entry: ChainEntry, + service: Option>, + binding: Option, + max_body_bytes: usize, +} + +impl DescribedChainEntry { + pub fn max_body_bytes(&self) -> usize { + self.max_body_bytes + } + + pub fn on_error(&self) -> OnError { + self.entry.on_error + } +} + +#[derive(Debug, Clone)] +pub struct HttpRequestInput { + pub request_id: String, + pub sandbox_id: String, + pub scheme: String, + pub host: String, + pub port: u16, + pub method: String, + pub path: String, + pub query: String, + pub headers: BTreeMap, + pub body: Vec, +} + +#[derive(Debug, Clone)] +pub struct ChainOutcome { + pub allowed: bool, + pub reason: String, + pub body: Vec, + pub added_headers: BTreeMap, + pub findings: Vec, + pub metadata: BTreeMap>, + pub applied: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NamespacedFinding { + pub middleware: String, + pub finding: Finding, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MiddlewareInvocation { + pub name: String, + pub implementation: String, + pub decision: Decision, + pub transformed: bool, + /// True when the middleware could not be evaluated and `on_error` was applied + /// (service error, malformed/unsafe response, etc.). The `decision` reflects + /// the `on_error` outcome, not a decision the middleware actually returned. + pub failed: bool, +} + +enum OnErrorAction { + /// `fail_open`: skip this middleware, leaving the request unchanged. + FailOpen, + /// `fail_closed`: short-circuit the chain and deny with the given reason. + FailClosed(String), +} + +/// Apply a middleware entry's `on_error` policy after a failure (service error or +/// malformed response). Records a `failed` invocation for telemetry in both cases. 
+fn apply_on_error( + entry: &DescribedChainEntry, + reason: &str, + applied: &mut Vec, +) -> OnErrorAction { + match entry.entry.on_error { + OnError::FailOpen => { + applied.push(MiddlewareInvocation { + name: entry.entry.name.clone(), + implementation: entry.entry.implementation.clone(), + decision: Decision::Allow, + transformed: false, + failed: true, + }); + OnErrorAction::FailOpen + } + OnError::FailClosed => { + applied.push(MiddlewareInvocation { + name: entry.entry.name.clone(), + implementation: entry.entry.implementation.clone(), + decision: Decision::Deny, + transformed: false, + failed: true, + }); + OnErrorAction::FailClosed(format!("middleware_failed: {reason}")) + } + } +} + +#[derive(Clone)] +pub struct ChainRunner { + registry: Arc, +} + +struct MiddlewareServiceState { + service: Arc, + manifest: OnceCell, + operator_max_body_bytes: Option, +} + +static IN_PROCESS_SERVICE: LazyLock> = LazyLock::new(|| { + Arc::new(MiddlewareServiceState { + service: Arc::new(InProcessMiddlewareService), + manifest: OnceCell::new(), + operator_max_body_bytes: None, + }) +}); + +/// Validated middleware services available to a gateway or one supervisor. +/// +/// The registry always contains the in-process built-ins. Operator-registered +/// services are connected and described before construction succeeds, so callers never +/// observe a partially registered service set. 
#[derive(Clone)]
pub struct MiddlewareRegistry {
    // Service states searched when a binding id has to be resolved.
    services: Arc<Vec<Arc<MiddlewareServiceState>>>,
    // Operator registrations paired with the binding ids each one described.
    registered_services: Arc<Vec<RegisteredMiddlewareService>>,
}

// Hand-written so only the two collection sizes are printed.
impl std::fmt::Debug for MiddlewareRegistry {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("MiddlewareRegistry");
        out.field("service_count", &self.services.len());
        out.field("registered_service_count", &self.registered_services.len());
        out.finish()
    }
}

/// An operator registration and the binding ids its service described.
#[derive(Clone)]
struct RegisteredMiddlewareService {
    registration: SupervisorMiddlewareService,
    binding_ids: Vec<String>,
}

impl Default for MiddlewareRegistry {
    /// A registry holding only the in-process service and no registrations.
    fn default() -> Self {
        let services = vec![Arc::clone(&IN_PROCESS_SERVICE)];
        Self {
            services: Arc::new(services),
            registered_services: Arc::new(Vec::new()),
        }
    }
}

/// Checks the static fields of an operator registration before connecting.
///
/// Rejects, in order: a blank name, a registration without
/// `allow_insecure`, an endpoint that does not start with `http://`, and a
/// zero `max_body_bytes`.
fn validate_registration(registration: &SupervisorMiddlewareService) -> Result<()> {
    let name = &registration.name;
    if name.trim().is_empty() {
        Err(miette!(
            "supervisor middleware registration name cannot be empty"
        ))
    } else if !registration.allow_insecure {
        Err(miette!(
            "middleware registration '{}' must set allow_insecure = true; TLS is not supported in V1",
            name
        ))
    } else if !registration.endpoint.starts_with("http://") {
        Err(miette!(
            "middleware registration '{}' endpoint must use http:// in the local-development-only V1",
            name
        ))
    } else if registration.max_body_bytes == 0 {
        Err(miette!(
            "middleware registration '{}' max_body_bytes must be greater than zero",
            name
        ))
    } else {
        Ok(())
    }
}

/// Validates one described binding against the registration's limits.
fn validate_described_binding(
    registration: &SupervisorMiddlewareService,
    binding: &MiddlewareBinding,
    operator_max_body_bytes: usize,
) -> Result<()> {
    if binding.id.trim().is_empty() {
        return Err(miette!(
            "middleware registration '{}' describes an empty binding id",
            registration.name
        ));
    }
    // The `openshell/` namespace is not available to external services.
    if binding.id.starts_with("openshell/") {
        return Err(miette!(
            "external middleware registration '{}' cannot claim reserved binding '{}'",
            registration.name,
            binding.id
        ));
    }
    let supported =
        binding.operation == HTTP_REQUEST_OPERATION && binding.phase == PRE_CREDENTIALS_PHASE;
    if !supported {
        return Err(miette!(
            "middleware binding '{}' must support {HTTP_REQUEST_OPERATION}/{PRE_CREDENTIALS_PHASE}",
            binding.id
        ));
    }
    let Ok(advertised) = usize::try_from(binding.max_body_bytes) else {
        return Err(miette!(
            "middleware binding '{}' reports a body limit too large for this platform",
            binding.id
        ));
    };
    if advertised == 0 {
        return Err(miette!(
            "middleware binding '{}' must advertise a non-zero body limit",
            binding.id
        ));
    }
    // The operator limit may narrow the advertised capability, never widen it.
    if operator_max_body_bytes > advertised {
        return Err(miette!(
            "middleware registration '{}' max_body_bytes ({operator_max_body_bytes}) exceeds binding '{}' capability ({advertised})",
            registration.name,
            binding.id
        ));
    }
    Ok(())
}

/// Validates the manifest returned by an external service and returns the
/// binding ids it describes.
///
/// `known_binding_ids` accumulates ids across calls; an id that is already
/// present is rejected as a duplicate.
fn validate_external_manifest(
    registration: &SupervisorMiddlewareService,
    manifest: &MiddlewareManifest,
    operator_max_body_bytes: usize,
    known_binding_ids: &mut HashSet<String>,
) -> Result<Vec<String>> {
    if manifest.api_version != API_VERSION {
        return Err(miette!(
            "middleware registration '{}' reports unsupported API version '{}'",
            registration.name,
            manifest.api_version
        ));
    }
    if manifest.bindings.is_empty() {
        return Err(miette!(
            "middleware registration '{}' describes no bindings",
            registration.name
        ));
    }

    let mut described_ids = Vec::with_capacity(manifest.bindings.len());
    for binding in &manifest.bindings {
        validate_described_binding(registration, binding, operator_max_body_bytes)?;
        let newly_seen = known_binding_ids.insert(binding.id.clone());
        if !newly_seen {
            return Err(miette!(
                "middleware binding '{}' is described by more than one service",
                binding.id
            ));
        }
        described_ids.push(binding.id.clone());
    }
    Ok(described_ids)
}

impl MiddlewareRegistry {
    /// Connect and validate every operator-provided service registration.
+ pub async fn connect_services(registrations: Vec) -> Result { + let mut services = vec![Arc::clone(&IN_PROCESS_SERVICE)]; + let mut registered_services = Vec::with_capacity(registrations.len()); + let mut registration_names = HashSet::new(); + let mut binding_ids = HashSet::from([BUILTIN_SECRETS.to_string()]); + + for registration in registrations { + validate_registration(®istration)?; + if !registration_names.insert(registration.name.clone()) { + return Err(miette!( + "duplicate supervisor middleware registration name '{}'", + registration.name + )); + } + + let operator_max_body_bytes = + usize::try_from(registration.max_body_bytes).map_err(|_| { + miette!( + "middleware registration '{}' body limit is too large for this platform", + registration.name + ) + })?; + let service = Arc::new( + remote::RemoteMiddlewareService::connect( + ®istration.name, + ®istration.endpoint, + ) + .await?, + ); + let manifest = service + .describe(Request::new(())) + .await + .map(tonic::Response::into_inner) + .map_err(|error| { + miette!( + "middleware registration '{}' Describe failed: {}", + registration.name, + safe_reason(&error.to_string()) + ) + })?; + let described_ids = validate_external_manifest( + ®istration, + &manifest, + operator_max_body_bytes, + &mut binding_ids, + )?; + let manifest_cell = OnceCell::new(); + manifest_cell + .set(manifest) + .map_err(|_| miette!("middleware manifest cache initialized twice"))?; + services.push(Arc::new(MiddlewareServiceState { + service, + manifest: manifest_cell, + operator_max_body_bytes: Some(operator_max_body_bytes), + })); + registered_services.push(RegisteredMiddlewareService { + registration, + binding_ids: described_ids, + }); + } + + Ok(Self { + services: Arc::new(services), + registered_services: Arc::new(registered_services), + }) + } + + /// Validate implementation-owned configuration for every middleware entry. 
+ pub async fn validate_policy_configs(&self, policy: &SandboxPolicy) -> Result<()> { + let runner = ChainRunner::from_registry(self.clone()); + for config in &policy.network_middlewares { + runner + .validate_config( + &config.middleware, + config.config.clone().unwrap_or_default(), + ) + .await + .map_err(|error| { + miette!( + "middleware config '{}' is invalid: {}", + config.name, + safe_reason(&error.to_string()) + ) + })?; + } + Ok(()) + } + + /// Check that every policy binding still belongs to the current static + /// registry without making a network call. + pub fn ensure_policy_bindings_registered(&self, policy: &SandboxPolicy) -> Result<()> { + for config in &policy.network_middlewares { + let registered = config.middleware == BUILTIN_SECRETS + || self.registered_services.iter().any(|service| { + service + .binding_ids + .iter() + .any(|binding| binding == &config.middleware) + }); + if !registered { + return Err(miette!( + "middleware binding '{}' used by config '{}' is not registered", + config.middleware, + config.name + )); + } + } + Ok(()) + } + + /// Return only operator-registered services referenced by the effective policy. 
+ pub fn required_services( + &self, + policy: Option<&SandboxPolicy>, + ) -> Vec { + let Some(policy) = policy else { + return Vec::new(); + }; + let selected: HashSet<&str> = policy + .network_middlewares + .iter() + .map(|config| config.middleware.as_str()) + .collect(); + self.registered_services + .iter() + .filter(|service| { + service + .binding_ids + .iter() + .any(|binding| selected.contains(binding.as_str())) + }) + .map(|service| service.registration.clone()) + .collect() + } +} + +impl Default for ChainRunner { + fn default() -> Self { + Self::from_registry(MiddlewareRegistry::default()) + } +} + +impl ChainRunner { + pub fn new(service: Arc) -> Self { + Self { + registry: Arc::new(MiddlewareRegistry { + services: Arc::new(vec![Arc::new(MiddlewareServiceState { + service, + manifest: OnceCell::new(), + operator_max_body_bytes: None, + })]), + registered_services: Arc::new(Vec::new()), + }), + } + } + + pub fn from_registry(registry: MiddlewareRegistry) -> Self { + Self { + registry: Arc::new(registry), + } + } + + async fn manifests(&self) -> Result, MiddlewareManifest)>> { + let mut manifests = Vec::with_capacity(self.registry.services.len()); + for state in self.registry.services.iter() { + let manifest = state + .manifest + .get_or_try_init(|| async { + state + .service + .describe(Request::new(())) + .await + .map(tonic::Response::into_inner) + .map_err(|error| { + miette!( + "middleware Describe failed: {}", + safe_reason(&error.to_string()) + ) + }) + }) + .await?; + manifests.push((Arc::clone(state), manifest.clone())); + } + Ok(manifests) + } + + pub async fn describe_chain(&self, entries: &[ChainEntry]) -> Result> { + let manifests = self.manifests().await?; + entries + .iter() + .map(|entry| { + let described = manifests.iter().find_map(|(state, manifest)| { + manifest + .bindings + .iter() + .find(|binding| binding.id == entry.implementation) + .cloned() + .map(|binding| (Arc::clone(state), binding)) + }); + let (service, binding) = 
described.map_or((None, None), |(service, binding)| { + (Some(service), Some(binding)) + }); + let max_body_bytes = binding + .as_ref() + .map(|binding| { + let advertised = usize::try_from(binding.max_body_bytes).map_err(|_| { + miette!( + "middleware binding '{}' reports a body limit too large for this platform", + binding.id + ) + })?; + Ok::<_, miette::Report>(service + .as_ref() + .and_then(|state| state.operator_max_body_bytes) + .unwrap_or(advertised)) + }) + .transpose()? + .unwrap_or(0); + Ok(DescribedChainEntry { + entry: entry.clone(), + service, + binding, + max_body_bytes, + }) + }) + .collect() + } + + pub async fn validate_config( + &self, + implementation: &str, + config: prost_types::Struct, + ) -> Result<()> { + let manifests = self.manifests().await?; + let Some((state, _)) = manifests.iter().find(|(_, manifest)| { + manifest + .bindings + .iter() + .any(|binding| binding.id == implementation) + }) else { + return Err(miette!( + "middleware binding '{implementation}' is not registered" + )); + }; + let response = state + .service + .validate_config(Request::new(ValidateConfigRequest { + api_version: API_VERSION.into(), + binding_id: implementation.into(), + config: Some(config), + })) + .await + .map(tonic::Response::into_inner) + .map_err(|error| { + miette!( + "middleware ValidateConfig failed: {}", + safe_reason(&error.to_string()) + ) + })?; + if response.valid { + Ok(()) + } else { + Err(miette!("{}", safe_reason(&response.reason))) + } + } + + pub async fn evaluate( + &self, + entries: &[ChainEntry], + input: HttpRequestInput, + ) -> Result { + let entries = self.describe_chain(entries).await?; + self.evaluate_described(&entries, input).await + } + + pub async fn evaluate_described( + &self, + entries: &[DescribedChainEntry], + input: HttpRequestInput, + ) -> Result { + let mut headers = input.headers.clone(); + let mut body = input.body.clone(); + let mut added_headers = BTreeMap::new(); + let mut findings = Vec::new(); + let mut metadata 
= BTreeMap::new(); + let mut applied = Vec::new(); + + for entry in entries { + let Some(binding) = entry.binding.as_ref() else { + match apply_on_error(entry, "binding_not_described", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + }; + if body.len() > entry.max_body_bytes { + match apply_on_error(entry, "request_body_over_capacity", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + } + let evaluation = build_evaluation(entry, binding, &input, &headers, &body); + let Some(service) = entry.service.as_ref() else { + unreachable!("described binding always has a service") + }; + let result = match service + .service + .evaluate_http_request(Request::new(evaluation)) + .await + { + Ok(result) => result.into_inner(), + Err(err) => { + match apply_on_error(entry, &safe_reason(&err.to_string()), &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + } + }; + + let decision = match Decision::try_from(result.decision) { + Ok(decision @ (Decision::Allow | Decision::Deny)) => decision, + Ok(Decision::Unspecified) | Err(_) => { + match apply_on_error(entry, "invalid_response_decision", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + } + }; + + if result.has_body && result.body.len() > entry.max_body_bytes { + match apply_on_error(entry, "response_body_over_capacity", &mut applied) { 
+ OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + } + + // A result proposing unsafe header mutations is a malformed response: + // route it through `on_error` instead of applying any of it. + if validate_header_mutations(&headers, &result.add_headers).is_err() { + match apply_on_error(entry, "unsafe_response_headers", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + } + for (name, value) in &result.add_headers { + headers.insert(name.to_ascii_lowercase(), value.clone()); + added_headers.insert(name.to_ascii_lowercase(), value.clone()); + } + let transformed = result.has_body; + if result.has_body { + result.body.clone_into(&mut body); + } + for finding in result.findings { + findings.push(NamespacedFinding { + middleware: entry.entry.name.clone(), + finding, + }); + } + if !result.metadata.is_empty() { + metadata.insert( + entry.entry.name.clone(), + result.metadata.clone().into_iter().collect(), + ); + } + applied.push(MiddlewareInvocation { + name: entry.entry.name.clone(), + implementation: entry.entry.implementation.clone(), + decision, + transformed, + failed: false, + }); + if decision == Decision::Deny { + return Ok(ChainOutcome { + allowed: false, + reason: safe_reason(&result.reason), + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + + Ok(ChainOutcome { + allowed: true, + reason: String::new(), + body, + added_headers, + findings, + metadata, + applied, + }) + } +} + +fn build_evaluation( + entry: &DescribedChainEntry, + binding: &MiddlewareBinding, + input: &HttpRequestInput, + headers: &BTreeMap, + body: &[u8], +) -> HttpRequestEvaluation { + HttpRequestEvaluation { + api_version: 
API_VERSION.into(), + binding_id: binding.id.clone(), + phase: binding.phase.clone(), + context: Some(RequestContext { + request_id: input.request_id.clone(), + sandbox_id: input.sandbox_id.clone(), + originating_process: None, + }), + config: Some(entry.entry.config.clone()), + target: Some(HttpRequestTarget { + scheme: input.scheme.clone(), + host: input.host.clone(), + port: u32::from(input.port), + method: input.method.clone(), + path: input.path.clone(), + query: input.query.clone(), + }), + headers: headers.clone().into_iter().collect(), + body: body.to_vec(), + } +} + +fn validate_header_mutations( + existing_headers: &BTreeMap, + mutations: &HashMap, +) -> Result<()> { + let mut seen = HashSet::new(); + for (name, value) in mutations { + let lower = name.to_ascii_lowercase(); + if !seen.insert(lower.clone()) || existing_headers.contains_key(&lower) { + return Err(miette!( + "middleware cannot rewrite existing header '{name}'" + )); + } + if !is_safe_append_header(&lower) { + return Err(miette!("middleware cannot append unsafe header '{name}'")); + } + // Reject CR/LF and other control characters in the value: writing them + // verbatim into the upstream header block would enable header injection + // and request smuggling past the credential boundary. + if !is_safe_header_value(value) { + return Err(miette!( + "middleware cannot append header '{name}' with an unsafe value" + )); + } + } + Ok(()) +} + +/// A header value is safe to append only if it contains no control characters. +/// Horizontal tab, printable ASCII, and obs-text (>= 0x80) are permitted; CR, LF, +/// NUL, and other control bytes are rejected. 
fn is_safe_header_value(value: &str) -> bool {
    value
        .bytes()
        .all(|b| b == b'\t' || (0x20..=0x7e).contains(&b) || b >= 0x80)
}

/// Returns whether `b` may appear in an HTTP field name (an RFC 9110 `token`
/// character): ASCII letters, digits, and ``!#$%&'*+-.^_`|~``.
fn is_header_token_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric()
        || matches!(
            b,
            b'!' | b'#'
                | b'$'
                | b'%'
                | b'&'
                | b'\''
                | b'*'
                | b'+'
                | b'-'
                | b'.'
                | b'^'
                | b'_'
                | b'`'
                | b'|'
                | b'~'
        )
}

/// Returns whether a middleware may append a header called `name`.
///
/// The prefix comparisons are case-sensitive, so `name` must already be
/// lowercased. The name must be a non-empty RFC 9110 token, which excludes
/// `:`, whitespace, control bytes, non-ASCII bytes and HTTP delimiters such
/// as `(`, `"`, `/`, `;` and `=`. It must not be a credential or framing
/// header, and it must start with `x-openshell-middleware-`.
fn is_safe_append_header(name: &str) -> bool {
    let is_token = !name.is_empty() && name.bytes().all(is_header_token_byte);
    if !is_token {
        return false;
    }
    // The deny-list is redundant with the allow-listed prefix below and is
    // kept as defense in depth.
    let denied = matches!(
        name,
        "authorization" | "cookie" | "host" | "content-length" | "transfer-encoding"
    ) || name.starts_with("x-amz-")
        || name.starts_with("x-openshell-credential");
    !denied && name.starts_with("x-openshell-middleware-")
}

/// Reduces a free-form reason to ASCII alphanumerics plus `_`, `-`, `:` and
/// space, keeping at most 160 of the surviving characters.
pub(crate) fn safe_reason(reason: &str) -> String {
    reason
        .chars()
        .filter(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-' | ':' | ' '))
        .take(160)
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use openshell_core::proto::middleware::v1::supervisor_middleware_server::{
        SupervisorMiddleware, SupervisorMiddlewareServer,
    };
    use tokio_stream::wrappers::TcpListenerStream;

    fn entry(name: &str, on_error: OnError) -> ChainEntry {
        ChainEntry {
            name: name.into(),
            implementation: BUILTIN_SECRETS.into(),
            config: prost_types::Struct {
                fields: std::iter::once((
                    "secrets".into(),
                    prost_types::Value {
                        kind: Some(prost_types::value::Kind::StringValue("redact".into())),
                    },
                ))
                .collect(),
            },
            on_error,
        }
    }

    fn input(body: &str) -> HttpRequestInput {
        HttpRequestInput {
            request_id: "req".into(),
            sandbox_id: "sbx".into(),
            scheme: "https".into(),
            host: "api.example.com".into(),
            port: 443,
            method: "POST".into(),
            path: "/v1".into(),
            query: String::new(),
            headers: BTreeMap::new(),
            body: body.as_bytes().to_vec(),
        }
    }

    #[tokio::test]
    async fn phase_one_evaluation_omits_originating_process() {
        let entries = ChainRunner::default()
            .describe_chain(&[entry("redact", OnError::FailClosed)])
            .await
            .expect("describe chain");
        let entry = &entries[0];
        let binding = entry.binding.as_ref().expect("described binding");
        let input = input("payload");
        let evaluation = build_evaluation(entry, binding, &input, &BTreeMap::new(), b"payload");

        assert!(
            evaluation
                .context
                .expect("request context")
                .originating_process
                .is_none()
        );
    }

    #[tokio::test]
    async fn redacts_common_secret_patterns() {
        let outcome = ChainRunner::default()
            .evaluate(
                &[entry("redact", OnError::FailClosed)],
                input(r#"{"api_key":"sk-1234567890abcdef"}"#),
            )
            .await
            .expect("evaluate");
        assert!(outcome.allowed);
        assert_eq!(
            String::from_utf8(outcome.body).expect("utf8"),
            r#"{"api_key":"[REDACTED]"}"#
        );
        assert_eq!(outcome.findings[0].finding.count, 1);
    }

    #[tokio::test]
    async fn transformed_body_feeds_next_stage() {
        let entries = [
            entry("first", OnError::FailClosed),
            entry("second", OnError::FailClosed),
        ];
        let outcome = ChainRunner::default()
            .evaluate(&entries, input(r#"password="top-secret""#))
            .await
            .expect("evaluate");
        assert!(outcome.allowed);
        assert_eq!(
            String::from_utf8(outcome.body).expect("utf8"),
            r#"password="[REDACTED]""#
        );
        assert_eq!(outcome.applied.len(), 2);
    }

    #[tokio::test]
    async fn fail_open_allows_unavailable_middleware() {
        let unavailable = ChainEntry {
            name: "missing".into(),
            implementation: "third-party/missing".into(),
            config: prost_types::Struct::default(),
            on_error: OnError::FailOpen,
        };
        let outcome = ChainRunner::default()
            .evaluate(&[unavailable], input("hello"))
            .await
            .expect("evaluate");
        assert!(outcome.allowed);
        assert_eq!(outcome.body, b"hello");
    }

    #[tokio::test]
    async fn fail_closed_denies_unavailable_middleware() {
        let unavailable = ChainEntry {
            name: "missing".into(),
            implementation: "third-party/missing".into(),
            config: prost_types::Struct::default(),
            on_error: OnError::FailClosed,
        };
        let outcome = ChainRunner::default()
            .evaluate(&[unavailable], input("hello"))
            .await
            .expect("evaluate");
        assert!(!outcome.allowed);
        assert!(outcome.reason.starts_with("middleware_failed:"));
    }

    #[tokio::test]
    async fn in_process_service_describes_builtin_binding() {
        let manifest = InProcessMiddlewareService
            .describe(Request::new(()))
            .await
            .expect("describe")
            .into_inner();
        assert_eq!(manifest.api_version, API_VERSION);
        assert_eq!(manifest.bindings[0].id, BUILTIN_SECRETS);
        assert_eq!(manifest.bindings[0].operation, "HttpRequest");
        assert_eq!(manifest.bindings[0].phase, "pre_credentials");
        assert_eq!(manifest.bindings[0].max_body_bytes, 256 * 1024);
    }

    #[test]
    fn unsafe_header_mutation_is_rejected() {
        let err = validate_header_mutations(
            &BTreeMap::new(),
            &std::iter::once(("Authorization".into(), "Bearer nope".into())).collect(),
        )
        .expect_err("unsafe header");
        assert!(err.to_string().contains("unsafe header"));
    }

    #[test]
    fn header_name_with_delimiter_is_rejected() {
        // Delimiters are not valid in an HTTP field name even under the
        // allow-listed prefix.
        assert!(is_safe_append_header("x-openshell-middleware-trace"));
        assert!(!is_safe_append_header("x-openshell-middleware-a(b)"));
        assert!(!is_safe_append_header("x-openshell-middleware-a\"b"));
        assert!(!is_safe_append_header("x-openshell-middleware-a/b"));
    }

    #[test]
    fn header_value_with_crlf_is_rejected() {
        // A safe header *name* with a CRLF-bearing value must still be rejected,
        // otherwise it would inject extra headers into the upstream request.
        let err = validate_header_mutations(
            &BTreeMap::new(),
            &std::iter::once((
                "x-openshell-middleware-inject".into(),
                "ok\r\nAuthorization: Bearer evil".into(),
            ))
            .collect(),
        )
        .expect_err("crlf value");
        assert!(err.to_string().contains("unsafe value"));
    }

    /// A mock middleware that returns a fixed, caller-supplied result for every
    /// evaluation. Used to exercise chain behavior the built-in cannot produce
    /// (explicit deny, metadata, findings, unsafe header mutations).
+ struct ScriptedService { + binding_id: String, + max_body_bytes: u64, + result: openshell_core::proto::HttpRequestResult, + } + + #[tonic::async_trait] + impl SupervisorMiddleware for ScriptedService { + async fn describe( + &self, + _request: Request<()>, + ) -> std::result::Result, tonic::Status> { + Ok(tonic::Response::new(MiddlewareManifest { + api_version: API_VERSION.into(), + name: "test/middleware".into(), + service_version: "test".into(), + bindings: vec![MiddlewareBinding { + id: self.binding_id.clone(), + operation: "HttpRequest".into(), + phase: "pre_credentials".into(), + max_body_bytes: self.max_body_bytes, + }], + })) + } + + async fn validate_config( + &self, + _request: Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new( + openshell_core::proto::ValidateConfigResponse { + valid: true, + reason: String::new(), + }, + )) + } + + async fn evaluate_http_request( + &self, + _request: Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new(self.result.clone())) + } + } + + fn scripted_service(result: openshell_core::proto::HttpRequestResult) -> ScriptedService { + ScriptedService { + binding_id: BUILTIN_SECRETS.into(), + max_body_bytes: 256 * 1024, + result, + } + } + + fn allow_result() -> openshell_core::proto::HttpRequestResult { + openshell_core::proto::HttpRequestResult { + decision: Decision::Allow as i32, + reason: String::new(), + body: Vec::new(), + has_body: false, + add_headers: HashMap::new(), + findings: Vec::new(), + metadata: HashMap::new(), + } + } + + fn external_registration(max_body_bytes: u64) -> SupervisorMiddlewareService { + SupervisorMiddlewareService { + name: "local-guard-service".into(), + endpoint: "http://127.0.0.1:50051".into(), + allow_insecure: true, + max_body_bytes, + } + } + + async fn registry_with_external( + service: Arc, + registration: SupervisorMiddlewareService, + ) -> MiddlewareRegistry { + let manifest = 
service + .describe(Request::new(())) + .await + .expect("describe test service") + .into_inner(); + let operator_max_body_bytes = usize::try_from(registration.max_body_bytes).unwrap(); + let mut known = HashSet::from([BUILTIN_SECRETS.to_string()]); + let binding_ids = validate_external_manifest( + ®istration, + &manifest, + operator_max_body_bytes, + &mut known, + ) + .expect("valid external manifest"); + let manifest_cell = OnceCell::new(); + manifest_cell.set(manifest).expect("manifest cache"); + MiddlewareRegistry { + services: Arc::new(vec![ + Arc::clone(&IN_PROCESS_SERVICE), + Arc::new(MiddlewareServiceState { + service, + manifest: manifest_cell, + operator_max_body_bytes: Some(operator_max_body_bytes), + }), + ]), + registered_services: Arc::new(vec![RegisteredMiddlewareService { + registration, + binding_ids, + }]), + } + } + + #[tokio::test] + async fn descriptors_are_resolved_from_any_middleware_service() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + binding_id: "example/redactor".into(), + max_body_bytes: 4096, + result: allow_result(), + })); + let entry = ChainEntry { + name: "external".into(), + implementation: "example/redactor".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }; + + let described = runner + .describe_chain(std::slice::from_ref(&entry)) + .await + .expect("describe external middleware"); + assert_eq!(described[0].max_body_bytes(), 4096); + assert_eq!( + described[0] + .binding + .as_ref() + .expect("described binding") + .phase, + "pre_credentials" + ); + + let outcome = runner + .evaluate_described(&described, input("hello")) + .await + .expect("evaluate external middleware"); + assert!(outcome.allowed); + } + + #[tokio::test] + async fn mixed_builtin_and_external_chain_uses_operator_limit() { + let external = Arc::new(ScriptedService { + binding_id: "example/content-guard".into(), + max_body_bytes: 4096, + result: allow_result(), + }); + let registry = 
registry_with_external(external, external_registration(1024)).await; + let runner = ChainRunner::from_registry(registry); + let external_entry = ChainEntry { + name: "external".into(), + implementation: "example/content-guard".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }; + let entries = [entry("builtin", OnError::FailClosed), external_entry]; + + let described = runner + .describe_chain(&entries) + .await + .expect("describe chain"); + assert_eq!(described[0].max_body_bytes(), 256 * 1024); + assert_eq!(described[1].max_body_bytes(), 1024); + + let outcome = runner + .evaluate_described(&described, input(r#"password="top-secret""#)) + .await + .expect("evaluate mixed chain"); + assert!(outcome.allowed); + assert_eq!(outcome.applied.len(), 2); + assert_eq!( + String::from_utf8(outcome.body).expect("utf8"), + r#"password="[REDACTED]""# + ); + } + + #[test] + fn external_manifest_rejects_operator_limit_above_capability() { + let registration = external_registration(4097); + let manifest = MiddlewareManifest { + api_version: API_VERSION.into(), + name: "example/service".into(), + service_version: "test".into(), + bindings: vec![MiddlewareBinding { + id: "example/content-guard".into(), + operation: HTTP_REQUEST_OPERATION.into(), + phase: PRE_CREDENTIALS_PHASE.into(), + max_body_bytes: 4096, + }], + }; + let error = validate_external_manifest( + ®istration, + &manifest, + 4097, + &mut HashSet::from([BUILTIN_SECRETS.to_string()]), + ) + .expect_err("operator limit must fit capability"); + assert!(error.to_string().contains("exceeds")); + } + + #[test] + fn external_registration_requires_explicit_insecure_opt_in() { + let mut registration = external_registration(4096); + registration.allow_insecure = false; + let error = validate_registration(®istration).expect_err("opt-in required"); + assert!(error.to_string().contains("allow_insecure")); + } + + #[tokio::test] + async fn external_registry_invokes_remote_service_over_grpc() { + 
let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test middleware"); + let address = listener.local_addr().expect("test middleware address"); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let server = tonic::transport::Server::builder() + .add_service(SupervisorMiddlewareServer::new(ScriptedService { + binding_id: "example/content-guard".into(), + max_body_bytes: 4096, + result: allow_result(), + })) + .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }); + let server_task = tokio::spawn(server); + + let mut registration = external_registration(1024); + registration.endpoint = format!("http://{address}"); + let registry = MiddlewareRegistry::connect_services(vec![registration.clone()]) + .await + .expect("connect external middleware"); + let policy = SandboxPolicy { + network_middlewares: vec![NetworkMiddlewareConfig { + name: "guard".into(), + middleware: "example/content-guard".into(), + config: Some(prost_types::Struct::default()), + on_error: "fail_closed".into(), + endpoints: None, + }], + ..Default::default() + }; + + registry + .validate_policy_configs(&policy) + .await + .expect("remote config validates"); + assert_eq!( + registry.required_services(Some(&policy)), + vec![registration] + ); + + let outcome = ChainRunner::from_registry(registry) + .evaluate( + &[ChainEntry { + name: "guard".into(), + implementation: "example/content-guard".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }], + input("hello"), + ) + .await + .expect("remote evaluation"); + assert!(outcome.allowed); + + let _ = shutdown_tx.send(()); + server_task + .await + .expect("join test middleware") + .expect("serve"); + } + + #[tokio::test] + async fn deny_decision_short_circuits_chain() { + let runner = ChainRunner::new(Arc::new(scripted_service( + openshell_core::proto::HttpRequestResult { + decision: Decision::Deny as i32, + reason: 
"blocked_by_policy".into(), + ..allow_result() + }, + ))); + let outcome = runner + .evaluate( + &[ + entry("first", OnError::FailClosed), + entry("second", OnError::FailClosed), + ], + input("hello"), + ) + .await + .expect("evaluate"); + assert!(!outcome.allowed); + assert_eq!(outcome.reason, "blocked_by_policy"); + // The deny short-circuits the chain: the second middleware never runs. + assert_eq!(outcome.applied.len(), 1); + assert_eq!(outcome.applied[0].decision, Decision::Deny); + assert!(!outcome.applied[0].failed); + } + + #[tokio::test] + async fn metadata_and_findings_are_namespaced_per_config() { + let runner = ChainRunner::new(Arc::new(scripted_service( + openshell_core::proto::HttpRequestResult { + findings: vec![Finding { + r#type: "pii.email".into(), + label: "email address".into(), + count: 2, + confidence: "high".into(), + severity: "medium".into(), + }], + metadata: std::iter::once(("sensitivity".to_string(), "high".to_string())) + .collect(), + ..allow_result() + }, + ))); + let outcome = runner + .evaluate( + &[ + entry("alpha", OnError::FailClosed), + entry("beta", OnError::FailClosed), + ], + input("hello"), + ) + .await + .expect("evaluate"); + assert!(outcome.allowed); + // Metadata is bucketed under each config's local name, so two configs + // emitting the same key do not collide. + assert_eq!(outcome.metadata["alpha"]["sensitivity"], "high"); + assert_eq!(outcome.metadata["beta"]["sensitivity"], "high"); + // Findings are tagged with the emitting config's name. 
+ assert_eq!(outcome.findings.len(), 2); + assert_eq!(outcome.findings[0].middleware, "alpha"); + assert_eq!(outcome.findings[1].middleware, "beta"); + assert_eq!(outcome.findings[0].finding.r#type, "pii.email"); + assert_eq!(outcome.findings[0].finding.count, 2); + } + + fn unsafe_header_service() -> ScriptedService { + scripted_service(openshell_core::proto::HttpRequestResult { + add_headers: std::iter::once(( + "x-openshell-middleware-inject".to_string(), + "ok\r\nHost: evil".to_string(), + )) + .collect(), + ..allow_result() + }) + } + + #[tokio::test] + async fn malformed_response_headers_fail_closed_denies() { + let runner = ChainRunner::new(Arc::new(unsafe_header_service())); + let outcome = runner + .evaluate(&[entry("redact", OnError::FailClosed)], input("hello")) + .await + .expect("evaluate"); + assert!(!outcome.allowed); + assert!(outcome.reason.starts_with("middleware_failed:")); + assert!(outcome.applied.iter().any(|inv| inv.failed)); + // The unsafe header is never forwarded. 
+ assert!(outcome.added_headers.is_empty()); + } + + #[tokio::test] + async fn malformed_response_headers_fail_open_continues() { + let runner = ChainRunner::new(Arc::new(unsafe_header_service())); + let outcome = runner + .evaluate(&[entry("redact", OnError::FailOpen)], input("hello")) + .await + .expect("evaluate"); + assert!(outcome.allowed); + assert_eq!(outcome.body, b"hello"); + assert!(outcome.added_headers.is_empty()); + assert_eq!(outcome.applied.len(), 1); + assert!(outcome.applied[0].failed); + } + + #[tokio::test] + async fn oversized_replacement_body_honors_on_error() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + binding_id: BUILTIN_SECRETS.into(), + max_body_bytes: 4, + result: openshell_core::proto::HttpRequestResult { + body: b"too large".to_vec(), + has_body: true, + ..allow_result() + }, + })); + let fail_open = entry("small", OnError::FailOpen); + let mut fail_closed = fail_open.clone(); + fail_closed.on_error = OnError::FailClosed; + + let open_outcome = runner + .evaluate(&[fail_open], input("safe")) + .await + .expect("fail-open evaluation"); + assert!(open_outcome.allowed); + assert_eq!(open_outcome.body, b"safe"); + assert!(open_outcome.applied[0].failed); + + let closed_outcome = runner + .evaluate(&[fail_closed], input("safe")) + .await + .expect("fail-closed evaluation"); + assert!(!closed_outcome.allowed); + assert_eq!( + closed_outcome.reason, + "middleware_failed: response_body_over_capacity" + ); + assert!(closed_outcome.applied[0].failed); + } + + #[tokio::test] + async fn oversized_request_body_honors_on_error() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + binding_id: BUILTIN_SECRETS.into(), + max_body_bytes: 4, + result: allow_result(), + })); + let fail_open = entry("small", OnError::FailOpen); + let mut fail_closed = fail_open.clone(); + fail_closed.on_error = OnError::FailClosed; + + let open_outcome = runner + .evaluate(&[fail_open], input("hello")) + .await + .expect("fail-open evaluation"); 
+ assert!(open_outcome.allowed); + assert_eq!(open_outcome.body, b"hello"); + assert!(open_outcome.applied[0].failed); + + let closed_outcome = runner + .evaluate(&[fail_closed], input("hello")) + .await + .expect("fail-closed evaluation"); + assert!(!closed_outcome.allowed); + assert_eq!( + closed_outcome.reason, + "middleware_failed: request_body_over_capacity" + ); + assert!(closed_outcome.applied[0].failed); + } + + #[tokio::test] + async fn unspecified_decision_uses_fail_closed() { + let runner = ChainRunner::new(Arc::new(scripted_service( + openshell_core::proto::HttpRequestResult { + decision: Decision::Unspecified as i32, + ..allow_result() + }, + ))); + + let outcome = runner + .evaluate(&[entry("redact", OnError::FailClosed)], input("hello")) + .await + .expect("evaluate"); + + assert!(!outcome.allowed); + assert_eq!( + outcome.reason, + "middleware_failed: invalid_response_decision" + ); + assert!(outcome.applied[0].failed); + } +} diff --git a/crates/openshell-supervisor-middleware/src/remote.rs b/crates/openshell-supervisor-middleware/src/remote.rs new file mode 100644 index 000000000..dd147788b --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/remote.rs @@ -0,0 +1,91 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use std::time::Duration; + +use miette::{IntoDiagnostic, Result, WrapErr}; +use openshell_core::proto::middleware::v1::supervisor_middleware_client::SupervisorMiddlewareClient; +use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; +use openshell_core::proto::{ + HttpRequestEvaluation, HttpRequestResult, MiddlewareManifest, ValidateConfigRequest, + ValidateConfigResponse, +}; +use tonic::transport::{Channel, Endpoint}; +use tonic::{Request, Response, Status}; + +const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); +const RPC_TIMEOUT: Duration = Duration::from_secs(5); + +#[derive(Clone)] +pub struct RemoteMiddlewareService { + registration_name: String, + client: SupervisorMiddlewareClient, +} + +impl RemoteMiddlewareService { + pub async fn connect(registration_name: &str, endpoint: &str) -> Result { + let channel = Endpoint::from_shared(endpoint.to_string()) + .into_diagnostic() + .wrap_err_with(|| { + format!("middleware registration '{registration_name}' has an invalid endpoint") + })? + .connect_timeout(CONNECT_TIMEOUT) + .connect() + .await + .into_diagnostic() + .wrap_err_with(|| { + format!( + "middleware registration '{registration_name}' could not connect to {endpoint}" + ) + })?; + Ok(Self { + registration_name: registration_name.to_string(), + client: SupervisorMiddlewareClient::new(channel), + }) + } + + async fn with_timeout( + &self, + operation: &'static str, + future: impl Future, Status>>, + ) -> std::result::Result, Status> { + tokio::time::timeout(RPC_TIMEOUT, future) + .await + .map_err(|_| { + Status::deadline_exceeded(format!( + "middleware '{}' {operation} timed out", + self.registration_name + )) + })? 
+ } +} + +#[tonic::async_trait] +impl SupervisorMiddleware for RemoteMiddlewareService { + async fn describe( + &self, + request: Request<()>, + ) -> std::result::Result, Status> { + let mut client = self.client.clone(); + self.with_timeout("Describe", client.describe(request)) + .await + } + + async fn validate_config( + &self, + request: Request, + ) -> std::result::Result, Status> { + let mut client = self.client.clone(); + self.with_timeout("ValidateConfig", client.validate_config(request)) + .await + } + + async fn evaluate_http_request( + &self, + request: Request, + ) -> std::result::Result, Status> { + let mut client = self.client.clone(); + self.with_timeout("EvaluateHttpRequest", client.evaluate_http_request(request)) + .await + } +} diff --git a/crates/openshell-supervisor-middleware/src/service.rs b/crates/openshell-supervisor-middleware/src/service.rs new file mode 100644 index 000000000..51df8d070 --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/service.rs @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; +use openshell_core::proto::{ + HttpRequestEvaluation, HttpRequestResult, MiddlewareManifest, ValidateConfigRequest, + ValidateConfigResponse, +}; +use tonic::{Request, Response, Status}; + +use crate::{API_VERSION, builtins, safe_reason, validate_builtin_config}; + +#[derive(Debug, Default)] +pub struct InProcessMiddlewareService; + +#[tonic::async_trait] +impl SupervisorMiddleware for InProcessMiddlewareService { + async fn describe( + &self, + _request: Request<()>, + ) -> Result, Status> { + Ok(Response::new(MiddlewareManifest { + api_version: API_VERSION.into(), + name: "openshell/in-process".into(), + service_version: env!("CARGO_PKG_VERSION").into(), + bindings: builtins::describe(), + })) + } + + async fn validate_config( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let config = request.config.unwrap_or_default(); + let validation = validate_builtin_config(&request.binding_id, &config); + Ok(Response::new(match validation { + Ok(()) => ValidateConfigResponse { + valid: true, + reason: String::new(), + }, + Err(err) => ValidateConfigResponse { + valid: false, + reason: safe_reason(&err.to_string()), + }, + })) + } + + async fn evaluate_http_request( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let result = builtins::evaluate_http_request(&request) + .map_err(|err| Status::invalid_argument(safe_reason(&err.to_string())))?; + Ok(Response::new(result)) + } +} diff --git a/crates/openshell-supervisor-network/Cargo.toml b/crates/openshell-supervisor-network/Cargo.toml index 7d0079f7b..b8cae5113 100644 --- a/crates/openshell-supervisor-network/Cargo.toml +++ b/crates/openshell-supervisor-network/Cargo.toml @@ -15,6 +15,7 @@ openshell-core = { path = "../openshell-core" } openshell-ocsf = { path = "../openshell-ocsf" } 
openshell-policy = { path = "../openshell-policy" } openshell-router = { path = "../openshell-router" } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } apollo-parser = { workspace = true } aws-sigv4 = { version = "1", features = ["sign-http", "http1"] } @@ -28,6 +29,7 @@ glob = { workspace = true } hex = "0.4" ipnet = "2" miette = { workspace = true } +prost-types = { workspace = true } rcgen = { workspace = true } regorus = { version = "0.9", default-features = false, features = ["std", "arc", "glob"] } reqwest = { workspace = true } @@ -53,6 +55,7 @@ tempfile = "3" temp-env = "0.3" tokio-tungstenite = { workspace = true } futures = { workspace = true } +tracing-subscriber = { workspace = true } [target.'cfg(unix)'.dev-dependencies] libc = "0.2" diff --git a/crates/openshell-supervisor-network/data/sandbox-policy.rego b/crates/openshell-supervisor-network/data/sandbox-policy.rego index efcdf0732..fcc5838e1 100644 --- a/crates/openshell-supervisor-network/data/sandbox-policy.rego +++ b/crates/openshell-supervisor-network/data/sandbox-policy.rego @@ -842,20 +842,20 @@ _policy_endpoint_configs(policy) := [ep | endpoint_has_extended_config(ep) ] -# Collect matching endpoint configs across all policies. Iterates over -# _matching_policy_names (a set, safe from regorus variable collisions) -# then collects per-policy configs via the helper function. 
_matching_endpoint_configs := [cfg | some pname _matching_policy_names[pname] cfgs := _policy_endpoint_configs(data.network_policies[pname]) cfg := cfgs[_] + endpoint_has_extended_config(cfg) ] matched_endpoint_config := _matching_endpoint_configs[0] if { count(_matching_endpoint_configs) > 0 } +network_middlewares := object.get(data, "network_middlewares", []) + _policy_has_exact_declared_endpoint(policy) if { some ep ep := policy.endpoints[_] @@ -909,7 +909,7 @@ endpoint_path_matches_request(ep, request) if { } # An endpoint has extended config if it specifies L7 protocol, allowed_ips, -# or an explicit tls mode (e.g. tls: skip). +# middleware, or an explicit tls mode (e.g. tls: skip). endpoint_has_extended_config(ep) if { ep.protocol } @@ -918,6 +918,10 @@ endpoint_has_extended_config(ep) if { count(object.get(ep, "allowed_ips", [])) > 0 } +endpoint_has_extended_config(ep) if { + count(object.get(ep, "middleware", [])) > 0 +} + endpoint_has_extended_config(ep) if { ep.tls } diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index ed2bde113..84853f751 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -15,9 +15,12 @@ use miette::{IntoDiagnostic, Result, miette}; use openshell_core::activity::{ActivitySender, try_record_activity}; use openshell_core::secrets::{self, SecretResolver}; use openshell_ocsf::{ - ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, - NetworkActivityBuilder, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, + ActionId, ActivityId, DetectionFindingBuilder, DispositionId, Endpoint, FindingInfo, + HttpActivityBuilder, HttpRequest, NetworkActivityBuilder, SeverityId, StatusId, Url as OcsfUrl, + ocsf_emit, }; +use std::collections::BTreeMap; +use std::path::PathBuf; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn}; @@ -450,6 
+453,42 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, @@ -734,6 +773,292 @@ fn jsonrpc_engine_type(protocol: L7Protocol) -> &'static str { } } +pub(crate) enum MiddlewareApplyResult { + Allowed(crate::l7::provider::L7Request), + Denied(String), +} + +fn middleware_chain_body_limit( + chain: &[openshell_supervisor_middleware::DescribedChainEntry], +) -> Option { + chain + .iter() + .map(openshell_supervisor_middleware::DescribedChainEntry::max_body_bytes) + .min() +} + +pub(crate) async fn apply_middleware_chain( + req: crate::l7::provider::L7Request, + client: &mut C, + ctx: &L7EvalContext, + chain: Vec, + runner: &openshell_supervisor_middleware::ChainRunner, + generation_guard: &PolicyGenerationGuard, +) -> Result { + apply_middleware_chain_for_scheme(req, client, ctx, "https", chain, runner, generation_guard) + .await +} + +pub(crate) async fn apply_middleware_chain_for_scheme( + req: crate::l7::provider::L7Request, + client: &mut C, + ctx: 
&L7EvalContext, + scheme: &str, + chain: Vec, + runner: &openshell_supervisor_middleware::ChainRunner, + generation_guard: &PolicyGenerationGuard, +) -> Result { + if chain.is_empty() { + return Ok(MiddlewareApplyResult::Allowed(req)); + } + let chain = runner.describe_chain(&chain).await?; + let max_body_bytes = + middleware_chain_body_limit(&chain).expect("non-empty middleware chain has a body limit"); + let buffered = match crate::l7::rest::buffer_request_body_for_middleware( + &req, + client, + Some(generation_guard), + max_body_bytes, + ) + .await? + { + crate::l7::rest::BufferResult::Buffered(buffered) => buffered, + crate::l7::rest::BufferResult::OverCapacity { recoverable } => { + return Ok(resolve_unbuffered_body(ctx, req, &chain, recoverable)); + } + }; + let headers = safe_middleware_headers(&buffered.headers)?; + let query = raw_query_from_request_headers(&buffered.headers)?; + let input = middleware_request_input(scheme, &req, ctx, headers, query, buffered.body); + let outcome = runner.evaluate_described(&chain, input).await?; + emit_middleware_events(ctx, &req, &outcome); + let rebuilt = crate::l7::rest::rebuild_request_with_buffered_body( + &req, + &buffered.headers, + &outcome.body, + &outcome.added_headers, + )?; + if outcome.allowed { + Ok(MiddlewareApplyResult::Allowed(rebuilt)) + } else { + Ok(MiddlewareApplyResult::Denied(outcome.reason)) + } +} + +fn middleware_request_input( + scheme: &str, + req: &crate::l7::provider::L7Request, + ctx: &L7EvalContext, + headers: BTreeMap, + query: String, + body: Vec, +) -> openshell_supervisor_middleware::HttpRequestInput { + openshell_supervisor_middleware::HttpRequestInput { + request_id: uuid::Uuid::new_v4().to_string(), + sandbox_id: openshell_ocsf::ctx::ctx().sandbox_id.clone(), + scheme: scheme.into(), + host: ctx.host.clone(), + port: ctx.port, + method: req.action.clone(), + path: req.target.clone(), + query, + headers, + body, + } +} + +fn raw_query_from_request_headers(headers: &[u8]) -> Result { 
+ let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let target = header_str + .lines() + .next() + .and_then(|line| line.split_whitespace().nth(1)) + .ok_or_else(|| miette!("HTTP request line is missing a target"))?; + Ok(target + .split_once('?') + .map_or_else(String::new, |(_, query)| query.to_string())) +} + +/// Apply the chain's `on_error` policy when the request body cannot be buffered +/// for inspection because it exceeds the size cap. The RFC treats an unbufferable +/// body as an `on_error` event: it is denied unless every attached middleware is +/// `fail_open`, and passing it through is only safe when no bytes were consumed. +fn resolve_unbuffered_body( + ctx: &L7EvalContext, + req: crate::l7::provider::L7Request, + chain: &[openshell_supervisor_middleware::DescribedChainEntry], + recoverable: bool, +) -> MiddlewareApplyResult { + let all_fail_open = chain + .iter() + .all(|entry| entry.on_error() == openshell_supervisor_middleware::OnError::FailOpen); + if recoverable && all_fail_open { + emit_middleware_body_unavailable(ctx, false); + return MiddlewareApplyResult::Allowed(req); + } + emit_middleware_body_unavailable(ctx, true); + MiddlewareApplyResult::Denied("middleware_failed: request_body_over_capacity".into()) +} + +fn emit_middleware_body_unavailable(ctx: &L7EvalContext, denied: bool) { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(if denied { + SeverityId::High + } else { + SeverityId::Medium + }) + .finding_info(FindingInfo::new( + "openshell.middleware.body_unavailable", + "Supervisor middleware could not inspect request body", + )) + .evidence_pairs(&[ + ("policy", ctx.policy_name.as_str()), + ("host", ctx.host.as_str()), + ("disposition", if denied { "denied" } else { "fail_open" }), + ]) + .message(if denied { + "Request body exceeded middleware inspection cap; denied" + } else { + "Request body exceeded middleware inspection cap; passed 
through (fail_open)" + }) + .build(); + ocsf_emit!(event); +} + +fn safe_middleware_headers(headers: &[u8]) -> Result> { + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut out = BTreeMap::new(); + for line in header_str.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + if name.is_empty() + || matches!( + name.as_str(), + "authorization" | "cookie" | "host" | "content-length" | "transfer-encoding" + ) + || name.starts_with("x-amz-") + || name.starts_with("x-openshell-credential") + { + continue; + } + out.insert(name, value.trim().to_string()); + } + Ok(out) +} + +fn middleware_network_input(ctx: &L7EvalContext) -> crate::opa::NetworkInput { + crate::opa::NetworkInput { + host: ctx.host.clone(), + port: ctx.port, + binary_path: PathBuf::from(&ctx.binary_path), + binary_sha256: String::new(), + ancestors: ctx.ancestors.iter().map(PathBuf::from).collect(), + cmdline_paths: ctx.cmdline_paths.iter().map(PathBuf::from).collect(), + } +} + +fn emit_middleware_events( + ctx: &L7EvalContext, + req: &crate::l7::provider::L7Request, + outcome: &openshell_supervisor_middleware::ChainOutcome, +) { + for invocation in &outcome.applied { + let allowed = invocation.decision == openshell_core::proto::Decision::Allow; + let event = HttpActivityBuilder::new(openshell_ocsf::ctx::ctx()) + .activity(ActivityId::Other) + .action(if allowed { + ActionId::Allowed + } else { + ActionId::Denied + }) + .disposition(if allowed { + DispositionId::Allowed + } else { + DispositionId::Blocked + }) + .severity(if allowed { + SeverityId::Informational + } else { + SeverityId::Medium + }) + .http_request(HttpRequest::new( + &req.action, + OcsfUrl::new("http", &ctx.host, &req.target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(&ctx.policy_name, "middleware") + .message(format!( + "MIDDLEWARE {} 
{} decision={:?} transformed={} failed={}", + invocation.name, + invocation.implementation, + invocation.decision, + invocation.transformed, + invocation.failed + )) + .build(); + ocsf_emit!(event); + + // A middleware that failed but was bypassed under `fail_open` is an + // enforcement failure operators must be able to alert on, even though the + // request proceeded. + if invocation.failed && allowed { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(SeverityId::Medium) + .finding_info(FindingInfo::new( + "openshell.middleware.failure", + "Supervisor middleware failed open", + )) + .evidence_pairs(&[ + ("middleware", invocation.name.as_str()), + ("implementation", invocation.implementation.as_str()), + ]) + .message(format!( + "Middleware {} failed and was bypassed (fail_open)", + invocation.name + )) + .build(); + ocsf_emit!(event); + } + } + if !outcome.allowed && outcome.reason.starts_with("middleware_failed:") { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(SeverityId::High) + .finding_info(FindingInfo::new( + "openshell.middleware.failure", + "Supervisor middleware failure", + )) + .message("Required supervisor middleware failed closed") + .build(); + ocsf_emit!(event); + } + for finding in &outcome.findings { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(match finding.finding.severity.as_str() { + "high" => SeverityId::High, + "low" => SeverityId::Low, + _ => SeverityId::Medium, + }) + .finding_info(FindingInfo::new( + &finding.finding.r#type, + &finding.finding.label, + )) + .evidence_pairs(&[ + ("middleware", &finding.middleware), + ("count", &finding.finding.count.to_string()), + ]) + .message(format!( + "Middleware finding {} count={}", + finding.finding.r#type, finding.finding.count + )) + .build(); + ocsf_emit!(event); + } +} + /// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. 
async fn relay_rest( config: &L7EndpointConfig, @@ -903,6 +1228,42 @@ where let _ = &eval_target; if allowed || config.enforcement == EnforcementMode::Audit { + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + provider + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let req_with_auth = match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await { Ok(req) => req, @@ -1142,6 +1503,42 @@ where } if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? 
+ { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; // Future MCP response/SSE introspection or rewrite would hook here // before returning upstream bytes. The current policy schema has no // trusted-annotations or version-profile field, so MCP responses and @@ -1336,6 +1733,42 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? 
+ { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( &req, client, @@ -1674,6 +2107,7 @@ pub async fn relay_passthrough_with_credentials( upstream: &mut U, ctx: &L7EvalContext, generation_guard: &PolicyGenerationGuard, + middleware_engine: Option<&crate::opa::OpaEngine>, ) -> Result<()> where C: AsyncRead + AsyncWrite + Unpin + Send, @@ -1756,6 +2190,44 @@ where ocsf_emit!(event); } + let req = if let Some(engine) = middleware_engine { + let input = middleware_network_input(ctx); + let (chain, generation) = engine.query_middleware_chain_with_generation(&input)?; + if generation != generation_guard.captured_generation() { + return Ok(()); + } + let runner = engine.middleware_runner()?; + match apply_middleware_chain(req, client, ctx, chain, &runner, generation_guard).await? 
+ { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: "HTTP".into(), + target: redacted_target.clone(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + } + } else { + req + }; + let req_with_auth = match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await { Ok(req) => req, @@ -1901,6 +2373,64 @@ network_policies: (config, tunnel_engine, ctx, fixture) } + fn middleware_relay_context( + middleware_impl: &str, + on_error: &str, + ) -> (L7EndpointConfig, TunnelPolicyEngine, L7EvalContext) { + let data = format!( + r#" +network_middlewares: + - name: request-middleware + middleware: {middleware_impl} + on_error: {on_error} + endpoints: + include: ["api.example.test"] +network_policies: + rest_api: + name: rest_api + endpoints: + - host: api.example.test + port: 8080 + protocol: rest + enforcement: enforce + rules: + - allow: + method: POST + path: "/v1/**" + binaries: + - {{ path: /usr/bin/curl }} +"# + ); + let engine = OpaEngine::from_strings(TEST_POLICY, &data).unwrap(); + let input = NetworkInput { + host: "api.example.test".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: 
"api.example.test".into(), + port: 8080, + policy_name: "rest_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + (config, tunnel_engine, ctx) + } + fn passthrough_token_grant_relay_context( resolver_response: std::result::Result<&str, &str>, ) -> ( @@ -2112,7 +2642,10 @@ network_policies: .unwrap(); let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); - assert!(upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n")); + assert!( + upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n"), + "unexpected upstream request: {upstream_request:?}" + ); assert!(upstream_request.contains("Authorization: Bearer grant-token\r\n")); assert!(!upstream_request.contains("stale-token")); assert_eq!(authorization_header_count(&upstream_request), 1); @@ -2195,26 +2728,29 @@ network_policies: } #[tokio::test] - async fn passthrough_relay_injects_token_grant_authorization_header() { - let (generation_guard, ctx, fixture) = - passthrough_token_grant_relay_context(Ok("grant-token")); + async fn l7_rest_middleware_redacts_body_before_upstream() { + let (config, tunnel_engine, ctx) = + middleware_relay_context("openshell/secrets", "fail_closed"); let (mut app, mut relay_client) = tokio::io::duplex(8192); let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); let relay = tokio::spawn(async move { - relay_passthrough_with_credentials( + relay_with_inspection( + &config, + tunnel_engine, &mut relay_client, &mut relay_upstream, &ctx, - &generation_guard, ) .await }); - app.write_all( - b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nAuthorization: Bearer stale-token\r\nConnection: close\r\n\r\n", - ) - .await - .unwrap(); + let body = br#"{"api_key":"sk-1234567890abcdef"}"#; + let request = format!( + "POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: 
{}\r\nConnection: close\r\n\r\n{}", + body.len(), + std::str::from_utf8(body).unwrap() + ); + app.write_all(request.as_bytes()).await.unwrap(); let mut upstream_request = [0u8; 1024]; let n = tokio::time::timeout( @@ -2225,17 +2761,13 @@ network_policies: .expect("request should reach upstream") .unwrap(); let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); - - assert!(upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n")); - assert!(upstream_request.contains("Authorization: Bearer grant-token\r\n")); - assert!(!upstream_request.contains("stale-token")); - assert_eq!(authorization_header_count(&upstream_request), 1); + assert!(upstream_request.contains(r#""api_key":"[REDACTED]""#)); + assert!(!upstream_request.contains("sk-1234567890abcdef")); upstream .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") .await .unwrap(); - let mut client_response = [0u8; 512]; let n = tokio::time::timeout( std::time::Duration::from_secs(1), @@ -2246,101 +2778,769 @@ network_policies: .unwrap(); assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); drop(app); - tokio::time::timeout(std::time::Duration::from_secs(1), relay) .await .expect("relay should finish") .unwrap() .unwrap(); - - fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); } #[tokio::test] - async fn passthrough_relay_token_grant_failure_returns_bad_gateway_without_forwarding() { - let (generation_guard, ctx, fixture) = - passthrough_token_grant_relay_context(Err("oauth unavailable")); + async fn l7_rest_middleware_fail_closed_does_not_reach_upstream() { + let (config, tunnel_engine, ctx) = + middleware_relay_context("example/unavailable", "fail_closed"); let (mut app, mut relay_client) = tokio::io::duplex(8192); let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); let relay = tokio::spawn(async move { - relay_passthrough_with_credentials( + relay_with_inspection( + &config, + 
tunnel_engine, &mut relay_client, &mut relay_upstream, &ctx, - &generation_guard, ) .await }); app.write_all( - b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nConnection: close\r\n\r\n", + b"POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: 2\r\nConnection: close\r\n\r\n{}", ) .await .unwrap(); - tokio::time::timeout(std::time::Duration::from_secs(1), relay) + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) .await - .expect("relay should finish") - .unwrap() + .expect("denial should reach client") .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("middleware_failed")); - let mut client_response = [0u8; 512]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - app.read(&mut client_response), - ) - .await - .expect("bad gateway response should reach client") - .unwrap(); - assert!(String::from_utf8_lossy(&client_response[..n]).contains("502 Bad Gateway")); - - let mut upstream_request = [0u8; 128]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), upstream.read(&mut upstream_request), ) - .await - .expect("upstream should close without forwarded data") - .unwrap(); - assert_eq!(n, 0, "unauthenticated request must not reach upstream"); + .await; + assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive request bytes" + ); - fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); } - #[test] - fn websocket_text_policy_requires_explicit_message_rule() { + #[tokio::test] + async fn 
jsonrpc_middleware_fail_closed_does_not_reach_upstream() { let data = r#" +network_middlewares: + - name: request-middleware + middleware: example/unavailable + on_error: fail_closed + endpoints: + include: ["api.example.test"] network_policies: - ws_api: - name: ws_api + jsonrpc_api: + name: jsonrpc_api endpoints: - - host: gateway.example.test + - host: api.example.test port: 443 - protocol: websocket + protocol: json-rpc enforcement: enforce rules: - allow: - method: GET - path: "/ws" + method: reports.list binaries: - { path: /usr/bin/node } "#; let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); let input = NetworkInput { - host: "gateway.example.test".into(), + host: "api.example.test".into(), port: 443, binary_path: PathBuf::from("/usr/bin/node"), binary_sha256: "unused".into(), ancestors: vec![], cmdline_paths: vec![], }; - let generation = engine - .evaluate_network_action_with_generation(&input) - .unwrap() - .1; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .expect("endpoint config"); + let config = crate::l7::parse_l7_config(&endpoint_config.expect("json-rpc config")) + .expect("parse JSON-RPC config"); let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); let ctx = L7EvalContext { - host: "gateway.example.test".into(), + host: "api.example.test".into(), + port: 443, + policy_name: "jsonrpc_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_jsonrpc( + &config, + &tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + let body = br#"{"jsonrpc":"2.0","id":1,"method":"reports.list"}"#; + let request = format!( + 
"POST /rpc HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + std::str::from_utf8(body).unwrap() + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("denial should reach client") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("middleware_failed")); + + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_request), + ) + .await; + assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive request bytes" + ); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn l7_rest_middleware_over_capacity_fails_closed() { + let (config, tunnel_engine, ctx) = + middleware_relay_context("openshell/secrets", "fail_closed"); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + // A declared body far above the 256 KiB inspection cap must be denied + // (fail-closed) before the body is read or reaches the upstream. 
+ let request = format!( + "POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + 300 * 1024 + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("denial should reach client") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("request_body_over_capacity")); + + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_request), + ) + .await; + assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive request bytes" + ); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn over_capacity_resolution_honors_on_error() { + use openshell_supervisor_middleware::{ChainEntry, OnError}; + + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 443, + policy_name: "p".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let req = || crate::l7::provider::L7Request { + action: "POST".into(), + target: "/v1".into(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }; + let fail_open = ChainEntry { + name: "m".into(), + implementation: "openshell/secrets".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailOpen, + }; + let fail_closed = ChainEntry { + on_error: OnError::FailClosed, + ..fail_open.clone() + }; + + let runner = openshell_supervisor_middleware::ChainRunner::default(); + 
let open_chain = runner + .describe_chain(std::slice::from_ref(&fail_open)) + .await + .expect("describe fail-open chain"); + let mixed_chain = runner + .describe_chain(&[fail_open.clone(), fail_closed]) + .await + .expect("describe mixed chain"); + + // Recoverable (Content-Length over cap, nothing consumed) + all fail-open + // -> stream through unprocessed. + assert!(matches!( + resolve_unbuffered_body(&ctx, req(), &open_chain, true), + MiddlewareApplyResult::Allowed(_) + )); + // Any fail-closed entry -> deny. + assert!(matches!( + resolve_unbuffered_body(&ctx, req(), &mixed_chain, true), + MiddlewareApplyResult::Denied(_) + )); + // Not recoverable (chunked overflow already consumed bytes) -> deny even + // when every entry is fail-open. + assert!(matches!( + resolve_unbuffered_body(&ctx, req(), &open_chain, false), + MiddlewareApplyResult::Denied(_) + )); + } + + #[test] + fn middleware_keeps_the_raw_request_query() { + let query = raw_query_from_request_headers( + b"POST /v1/messages?token=a%2Bb&scope=private HTTP/1.1\r\nHost: api.example.test\r\n\r\n", + ) + .expect("query from request headers"); + + assert_eq!(query, "token=a%2Bb&scope=private"); + } + + #[test] + fn middleware_request_input_preserves_plain_http_scheme() { + let req = crate::l7::provider::L7Request { + action: "POST".into(), + target: "/v1/messages".into(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }; + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 80, + policy_name: "api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: Vec::new(), + cmdline_paths: Vec::new(), + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let input = middleware_request_input( + "http", + &req, + &ctx, + BTreeMap::new(), + String::new(), + Vec::new(), + ); + + assert_eq!(input.scheme, "http"); + } + + /// Tracing layer that captures 
emitted `OcsfEvent`s for assertions. + struct OcsfCaptureLayer(Arc>>); + + impl tracing_subscriber::Layer for OcsfCaptureLayer { + fn on_event( + &self, + event: &tracing::Event<'_>, + _ctx: tracing_subscriber::layer::Context<'_, S>, + ) { + if event.metadata().target() == openshell_ocsf::OCSF_TARGET + && let Some(ocsf_event) = openshell_ocsf::clone_current_event() + { + self.0.lock().unwrap().push(ocsf_event); + } + } + } + + #[test] + fn middleware_ocsf_events_are_audit_safe() { + use openshell_supervisor_middleware::{ + ChainOutcome, MiddlewareInvocation, NamespacedFinding, + }; + use tracing_subscriber::layer::SubscriberExt; + + const RAW_SECRET: &str = "sk-RAWSECRETVALUE0123456789"; + + let events = Arc::new(std::sync::Mutex::new(Vec::new())); + let subscriber = tracing_subscriber::registry().with(OcsfCaptureLayer(Arc::clone(&events))); + let _guard = tracing::subscriber::set_default(subscriber); + + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 443, + policy_name: "rest_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let req = crate::l7::provider::L7Request { + action: "POST".into(), + target: "/v1/messages".into(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }; + let outcome = ChainOutcome { + allowed: true, + reason: String::new(), + // The transformed body still holds the raw secret; emission must never + // serialize it. 
+ body: format!(r#"{{"api_key":"{RAW_SECRET}"}}"#).into_bytes(), + added_headers: BTreeMap::new(), + findings: vec![NamespacedFinding { + middleware: "redact-secrets".into(), + finding: openshell_core::proto::Finding { + r#type: "secret.common".into(), + label: "common secret pattern".into(), + count: 1, + confidence: "medium".into(), + severity: "medium".into(), + }, + }], + metadata: BTreeMap::new(), + applied: vec![MiddlewareInvocation { + name: "redact-secrets".into(), + implementation: "openshell/secrets".into(), + decision: openshell_core::proto::Decision::Allow, + transformed: true, + failed: false, + }], + }; + + emit_middleware_events(&ctx, &req, &outcome); + + let captured = events.lock().unwrap(); + // Per-invocation decisions are HTTP Activity (class 4002). + assert!( + captured.iter().any(|e| e.class_uid() == 4002), + "expected an HTTP Activity event for the middleware invocation" + ); + // Findings are Detection Finding (class 2004) with the finding's severity. + let finding_event = captured + .iter() + .find(|e| e.class_uid() == 2004) + .expect("expected a Detection Finding event"); + assert_eq!(finding_event.base().severity, SeverityId::Medium); + + // No raw payload material may appear in any emitted event. + let serialized = serde_json::to_string(&*captured).expect("serialize events"); + assert!( + !serialized.contains(RAW_SECRET), + "raw secret leaked into OCSF events: {serialized}" + ); + // Safe finding metadata is still present. + assert!(serialized.contains("secret.common")); + } + + #[tokio::test] + async fn passthrough_relay_runs_middleware_redaction() { + // A no-protocol endpoint takes the credential-injection passthrough path; + // host-selected middleware must still inspect and redact its body. 
+ let data = r#" +network_middlewares: + - name: request-middleware + middleware: openshell/secrets + on_error: fail_closed + endpoints: + include: ["api.example.test"] +network_policies: + passthrough_api: + name: passthrough_api + endpoints: + - host: api.example.test + port: 8080 + binaries: + - { path: /usr/bin/curl } +"#; + let engine = Arc::new(OpaEngine::from_strings(TEST_POLICY, data).unwrap()); + let generation_guard = engine + .generation_guard(engine.current_generation()) + .unwrap(); + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 8080, + policy_name: "passthrough_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let engine_task = Arc::clone(&engine); + let relay = tokio::spawn(async move { + relay_passthrough_with_credentials( + &mut relay_client, + &mut relay_upstream, + &ctx, + &generation_guard, + Some(engine_task.as_ref()), + ) + .await + }); + + let body = br#"{"api_key":"sk-1234567890abcdef"}"#; + let request = format!( + "POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + std::str::from_utf8(body).unwrap() + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut upstream_request = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("request should reach upstream") + .unwrap(); + let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); + assert!( + upstream_request.contains(r#""api_key":"[REDACTED]""#), + "unexpected upstream request: {upstream_request:?}" + ); + assert!(!upstream_request.contains("sk-1234567890abcdef")); + + upstream + 
.write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn websocket_upgrade_request_is_inspected_and_denied() { + // The WebSocket upgrade handshake is an HTTP request the hook can inspect + // and deny: a fail-closed middleware blocks the upgrade before it is + // forwarded. + let data = r#" +network_middlewares: + - name: request-middleware + middleware: example/unavailable + on_error: fail_closed + endpoints: + include: ["gateway.example.test"] +network_policies: + ws_api: + name: ws_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "gateway.example.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "ws_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: 
None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("denial should reach client") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("middleware_failed")); + + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_request), + ) + .await; + assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive the upgrade request" + ); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn passthrough_relay_injects_token_grant_authorization_header() { + let (generation_guard, ctx, fixture) = + passthrough_token_grant_relay_context(Ok("grant-token")); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_passthrough_with_credentials( + &mut relay_client, + &mut relay_upstream, + &ctx, + &generation_guard, + None, + ) + .await + }); + + app.write_all( + b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nAuthorization: Bearer stale-token\r\nConnection: 
close\r\n\r\n", + ) + .await + .unwrap(); + + let mut upstream_request = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("request should reach upstream") + .unwrap(); + let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); + + assert!(upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n")); + assert!(upstream_request.contains("Authorization: Bearer grant-token\r\n")); + assert!(!upstream_request.contains("stale-token")); + assert_eq!(authorization_header_count(&upstream_request), 1); + + upstream + .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); + drop(app); + + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + + fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); + } + + #[tokio::test] + async fn passthrough_relay_token_grant_failure_returns_bad_gateway_without_forwarding() { + let (generation_guard, ctx, fixture) = + passthrough_token_grant_relay_context(Err("oauth unavailable")); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_passthrough_with_credentials( + &mut relay_client, + &mut relay_upstream, + &ctx, + &generation_guard, + None, + ) + .await + }); + + app.write_all( + b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nConnection: close\r\n\r\n", + ) + .await + .unwrap(); + + tokio::time::timeout(std::time::Duration::from_secs(1), 
relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("bad gateway response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("502 Bad Gateway")); + + let mut upstream_request = [0u8; 128]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("upstream should close without forwarded data") + .unwrap(); + assert_eq!(n, 0, "unauthenticated request must not reach upstream"); + + fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); + } + + #[test] + fn websocket_text_policy_requires_explicit_message_rule() { + let data = r#" +network_policies: + ws_api: + name: ws_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "gateway.example.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let generation = engine + .evaluate_network_action_with_generation(&input) + .unwrap() + .1; + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), port: 443, policy_name: "ws_api".into(), binary_path: "/usr/bin/node".into(), @@ -3173,7 +4373,10 @@ network_policies: .expect("first request should reach upstream") .unwrap(); let first_upstream = String::from_utf8_lossy(&first_upstream[..n]); - assert!(first_upstream.starts_with("POST /write HTTP/1.1")); + assert!( + first_upstream.starts_with("POST /write 
HTTP/1.1"), + "unexpected upstream request: {first_upstream:?}" + ); upstream .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nOK") @@ -3243,6 +4446,7 @@ network_policies: &mut relay_upstream, &ctx, &generation_guard, + None, ) .await }); diff --git a/crates/openshell-supervisor-network/src/l7/rest.rs b/crates/openshell-supervisor-network/src/l7/rest.rs index 0558a67e5..15825d1b2 100644 --- a/crates/openshell-supervisor-network/src/l7/rest.rs +++ b/crates/openshell-supervisor-network/src/l7/rest.rs @@ -27,6 +27,19 @@ const MAX_REWRITE_BODY_BYTES: usize = 256 * 1024; /// Maximum body bytes for `SigV4` body-signing mode. Larger than the credential /// rewrite limit because Bedrock payloads can be several megabytes. const MAX_SIGV4_BODY_BYTES: usize = 10 * 1024 * 1024; +#[cfg(test)] +async fn max_middleware_body_bytes() -> usize { + let chain = openshell_supervisor_middleware::ChainRunner::default() + .describe_chain(&[openshell_supervisor_middleware::ChainEntry { + name: "test".into(), + implementation: openshell_supervisor_middleware::BUILTIN_SECRETS.into(), + config: prost_types::Struct::default(), + on_error: openshell_supervisor_middleware::OnError::FailClosed, + }]) + .await + .expect("describe built-in middleware"); + chain[0].max_body_bytes() +} const RELAY_BUF_SIZE: usize = 8192; const HTTP_METHOD_PREFIXES: &[&[u8]] = &[ b"GET ", @@ -245,6 +258,36 @@ async fn parse_http_request( })) } +/// Build an L7 request from a request already buffered by another proxy path. +/// +/// The forward proxy needs this after it has consumed the incoming HTTP/1 +/// headers itself. Keep the framing and query parsing here so it matches the +/// stream-based REST parser rather than growing another local parser. 
+pub(crate) fn request_from_buffered_http( + action: impl Into, + target: impl Into, + query_target: &str, + raw_header: Vec, +) -> Result { + let header_end = raw_header + .windows(4) + .position(|window| window == b"\r\n\r\n") + .ok_or_else(|| miette!("HTTP request headers are missing the CRLF terminator"))? + + 4; + let header_str = std::str::from_utf8(&raw_header[..header_end]) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let body_length = parse_body_length(header_str)?; + let (_, query_params) = parse_target_query(query_target)?; + + Ok(L7Request { + action: action.into(), + target: target.into(), + query_params, + raw_header, + body_length, + }) +} + /// Rebuild the request line in a raw HTTP header block with a canonicalized /// target. Called when the canonical path differs from what the client sent, /// so the upstream dispatches on the exact bytes the policy engine evaluated. @@ -768,6 +811,112 @@ struct PreparedRequestBody { body: Vec, } +pub(crate) struct BufferedRequestBody { + pub(crate) headers: Vec, + pub(crate) body: Vec, +} + +/// Result of attempting to buffer a request body for middleware inspection. +pub(crate) enum BufferResult { + /// The full body was buffered within the size cap. + Buffered(BufferedRequestBody), + /// The body exceeded the inspection cap. `recoverable` is true when no body + /// bytes were consumed yet (a declared `Content-Length` over the cap), so the + /// request can still be streamed through unprocessed under fail-open. It is + /// false once bytes have been consumed (chunked overflow), where denying is + /// the only safe outcome. 
+ OverCapacity { recoverable: bool }, +} + +pub(crate) async fn buffer_request_body_for_middleware( + req: &L7Request, + client: &mut C, + generation_guard: Option<&PolicyGenerationGuard>, + max_body_bytes: usize, +) -> Result { + let header_end = req + .raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(req.raw_header.len(), |p| p + 4); + let headers = req.raw_header[..header_end].to_vec(); + let already_read = &req.raw_header[header_end..]; + match req.body_length { + BodyLength::None => Ok(BufferResult::Buffered(BufferedRequestBody { + headers, + body: already_read.to_vec(), + })), + BodyLength::ContentLength(len) => { + // The declared length is known before any further reads, so an + // over-cap body here has not consumed the stream and can be passed + // through unprocessed if every middleware is fail-open. + let Ok(len) = usize::try_from(len) else { + return Ok(BufferResult::OverCapacity { recoverable: true }); + }; + if len > max_body_bytes { + return Ok(BufferResult::OverCapacity { recoverable: true }); + } + let initial_len = already_read.len().min(len); + let mut body = Vec::with_capacity(len); + body.extend_from_slice(&already_read[..initial_len]); + let mut remaining = len.saturating_sub(initial_len); + let mut buf = [0u8; RELAY_BUF_SIZE]; + while remaining > 0 { + let to_read = remaining.min(buf.len()); + let n = client.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!( + "Connection closed with {remaining} body bytes remaining" + )); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + body.extend_from_slice(&buf[..n]); + remaining -= n; + } + Ok(BufferResult::Buffered(BufferedRequestBody { + headers, + body, + })) + } + BodyLength::Chunked => { + // Chunked bodies are decoded incrementally into the payload bytes + // middleware expects, but the middleware cap counts the complete + // wire representation, including framing and trailers. 
On overflow, + // we have already consumed wire bytes from the client stream and + // cannot re-enter the normal raw relay path without a separate + // splice-through buffer. + Ok( + collect_chunked_body(client, already_read, generation_guard, Some(max_body_bytes)) + .await + .map_or(BufferResult::OverCapacity { recoverable: false }, |body| { + BufferResult::Buffered(BufferedRequestBody { headers, body }) + }), + ) + } + } +} + +pub(crate) fn rebuild_request_with_buffered_body( + req: &L7Request, + headers: &[u8], + body: &[u8], + add_headers: &std::collections::BTreeMap, +) -> Result { + let mut header_bytes = set_content_length(headers, body.len())?; + header_bytes = strip_header(&header_bytes, "transfer-encoding")?; + header_bytes = append_headers(&header_bytes, add_headers); + header_bytes.extend_from_slice(body); + Ok(L7Request { + action: req.action.clone(), + target: req.target.clone(), + query_params: req.query_params.clone(), + raw_header: header_bytes, + body_length: BodyLength::ContentLength(body.len() as u64), + }) +} + async fn collect_and_rewrite_request_body( req: &L7Request, client: &mut C, @@ -821,16 +970,12 @@ async fn collect_and_rewrite_request_body( Ok(PreparedRequestBody { headers, body }) } BodyLength::Chunked => { - let body = collect_chunked_body(client, already_read, generation_guard).await?; - if body_bytes_contain_reserved_marker(&body) { - return Err(miette!( - "request body credential rewrite does not support chunked bodies containing credential placeholders" - )); - } - Ok(PreparedRequestBody { - headers: rewritten_headers.to_vec(), - body, - }) + let body = collect_chunked_body(client, already_read, generation_guard, None).await?; + let (mut headers, body) = + rewrite_buffered_body(rewritten_headers, original_header_str, body, resolver)?; + headers = set_content_length(&headers, body.len())?; + headers = strip_header(&headers, "transfer-encoding")?; + Ok(PreparedRequestBody { headers, body }) } } } @@ -997,38 +1142,20 @@ async fn 
collect_chunked_body( client: &mut C, already_read: &[u8], generation_guard: Option<&PolicyGenerationGuard>, + max_wire_bytes: Option, ) -> Result> { - let mut read_buf = [0u8; RELAY_BUF_SIZE]; - let mut parse_buf = Vec::from(already_read); - let mut pos = 0usize; + let mut read_state = ChunkedReadState { + buffered_pos: 0, + wire_bytes: 0, + max_wire_bytes, + }; + let mut body = Vec::new(); loop { - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - )); - } - - let size_line_end = loop { - if let Some(end) = find_crlf(&parse_buf, pos) { - break end; - } - let n = client.read(&mut read_buf).await.into_diagnostic()?; - if n == 0 { - return Err(miette!("Chunked body ended before chunk-size line")); - } - if let Some(guard) = generation_guard { - guard.ensure_current()?; - } - parse_buf.extend_from_slice(&read_buf[..n]); - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - )); - } - }; - - let size_line = std::str::from_utf8(&parse_buf[pos..size_line_end]) + let size_line = read_chunked_line(client, already_read, &mut read_state, generation_guard) + .await + .map_err(|e| miette!("Chunked body ended before chunk-size line: {e}"))?; + let size_line = std::str::from_utf8(&size_line) .into_diagnostic() .map_err(|_| miette!("Invalid UTF-8 in chunk-size line"))?; let size_token = size_line @@ -1039,64 +1166,127 @@ async fn collect_chunked_body( let chunk_size = usize::from_str_radix(size_token, 16) .into_diagnostic() .map_err(|_| miette!("Invalid chunk size token: {size_token:?}"))?; - pos = size_line_end + 2; if chunk_size == 0 { loop { - let trailer_end = loop { - if let Some(end) = find_crlf(&parse_buf, pos) { - break end; - } - let n = client.read(&mut read_buf).await.into_diagnostic()?; - if n == 0 { - return Err(miette!("Chunked body ended before trailer 
terminator")); - } - if let Some(guard) = generation_guard { - guard.ensure_current()?; - } - parse_buf.extend_from_slice(&read_buf[..n]); - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - )); - } - }; - let trailer_line = &parse_buf[pos..trailer_end]; - pos = trailer_end + 2; + let trailer_line = + read_chunked_line(client, already_read, &mut read_state, generation_guard) + .await + .map_err(|e| { + miette!("Chunked body ended before trailer terminator: {e}") + })?; if trailer_line.is_empty() { - return Ok(parse_buf); + return Ok(body); } } } - let chunk_end = pos - .checked_add(chunk_size) - .ok_or_else(|| miette!("Chunk size overflow"))?; - let chunk_with_crlf_end = chunk_end - .checked_add(2) - .ok_or_else(|| miette!("Chunk size overflow"))?; - while parse_buf.len() < chunk_with_crlf_end { - let n = client.read(&mut read_buf).await.into_diagnostic()?; - if n == 0 { - return Err(miette!("Chunked body ended mid-chunk")); - } - if let Some(guard) = generation_guard { - guard.ensure_current()?; - } - parse_buf.extend_from_slice(&read_buf[..n]); - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - )); - } + if body.len().saturating_add(chunk_size) > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); } - if &parse_buf[chunk_end..chunk_with_crlf_end] != b"\r\n" { + read_buffered_exact( + client, + already_read, + &mut read_state, + chunk_size, + &mut body, + generation_guard, + ) + .await + .map_err(|e| miette!("Chunked body ended mid-chunk: {e}"))?; + + let mut chunk_crlf = Vec::with_capacity(2); + read_buffered_exact( + client, + already_read, + &mut read_state, + 2, + &mut chunk_crlf, + generation_guard, + ) + .await + .map_err(|e| miette!("Chunked body ended before chunk 
terminator: {e}"))?; + if chunk_crlf.as_slice() != b"\r\n" { return Err(miette!("Chunk missing terminating CRLF")); } - pos = chunk_with_crlf_end; } } +struct ChunkedReadState { + buffered_pos: usize, + wire_bytes: usize, + max_wire_bytes: Option, +} + +async fn read_chunked_line( + client: &mut C, + already_read: &[u8], + state: &mut ChunkedReadState, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result> { + let mut line = Vec::new(); + loop { + let byte = read_buffered_byte(client, already_read, state, generation_guard).await?; + line.push(byte); + if line.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + if line.ends_with(b"\r\n") { + line.truncate(line.len() - 2); + return Ok(line); + } + } +} + +async fn read_buffered_exact( + client: &mut C, + already_read: &[u8], + state: &mut ChunkedReadState, + len: usize, + out: &mut Vec, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result<()> { + for _ in 0..len { + let byte = read_buffered_byte(client, already_read, state, generation_guard).await?; + out.push(byte); + } + Ok(()) +} + +async fn read_buffered_byte( + client: &mut C, + already_read: &[u8], + state: &mut ChunkedReadState, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result { + if state + .max_wire_bytes + .is_some_and(|max| state.wire_bytes >= max) + { + return Err(miette!( + "chunked body wire representation exceeds middleware buffer limit" + )); + } + + let byte = if state.buffered_pos < already_read.len() { + let byte = already_read[state.buffered_pos]; + state.buffered_pos += 1; + byte + } else { + let byte = client.read_u8().await.into_diagnostic()?; + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + byte + }; + state.wire_bytes += 1; + Ok(byte) +} + fn content_type(headers: &str) -> Option { headers.lines().skip(1).find_map(|line| { let (name, value) = line.split_once(':')?; @@ -1160,6 
+1350,50 @@ fn set_content_length(headers: &[u8], len: usize) -> Result> { Ok(out.into_bytes()) } +fn strip_header(headers: &[u8], strip_name: &str) -> Result> { + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut out = String::with_capacity(header_str.len()); + for line in header_str.split("\r\n") { + if line.is_empty() { + out.push_str("\r\n"); + break; + } + if line + .split_once(':') + .is_some_and(|(name, _)| name.trim().eq_ignore_ascii_case(strip_name)) + { + continue; + } + out.push_str(line); + out.push_str("\r\n"); + } + Ok(out.into_bytes()) +} + +fn append_headers( + headers: &[u8], + add_headers: &std::collections::BTreeMap, +) -> Vec { + if add_headers.is_empty() { + return headers.to_vec(); + } + let split = headers + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(headers.len(), |pos| pos); + let mut out = Vec::with_capacity(headers.len() + add_headers.len() * 32); + out.extend_from_slice(&headers[..split]); + for (name, value) in add_headers { + out.extend_from_slice(b"\r\n"); + out.extend_from_slice(name.as_bytes()); + out.extend_from_slice(b": "); + out.extend_from_slice(value.as_bytes()); + } + out.extend_from_slice(b"\r\n\r\n"); + out +} + pub(crate) fn request_is_websocket_upgrade(raw_header: &[u8]) -> bool { let header_end = raw_header .windows(4) @@ -2850,6 +3084,39 @@ mod tests { } } + #[test] + fn buffered_request_parser_uses_shared_framing_and_query_parsing() { + let request = request_from_buffered_http( + "POST", + "/v1/items", + "/v1/items?tag=first&tag=second", + b"POST /v1/items?tag=first&tag=second HTTP/1.1\r\nHost: api.example.com\r\nContent-Length: 3\r\n\r\nabc" + .to_vec(), + ) + .expect("parse buffered request"); + + assert_eq!(request.action, "POST"); + assert_eq!(request.target, "/v1/items"); + assert_eq!( + request.query_params.get("tag"), + Some(&vec!["first".to_string(), "second".to_string()]) + ); + assert!(matches!(request.body_length, 
BodyLength::ContentLength(3))); + } + + #[test] + fn buffered_request_parser_rejects_missing_header_terminator() { + let err = request_from_buffered_http( + "GET", + "/v1/items", + "/v1/items", + b"GET /v1/items HTTP/1.1\r\nHost: api.example.com\r\n".to_vec(), + ) + .expect_err("unterminated headers must be rejected"); + + assert!(err.to_string().contains("missing the CRLF terminator")); + } + #[test] fn parse_chunked() { let headers = @@ -3029,6 +3296,55 @@ mod tests { } } + #[tokio::test] + async fn collect_chunked_body_decodes_payload_bytes() { + let mut client = tokio::io::empty(); + let body = collect_chunked_body( + &mut client, + b"5\r\nhello\r\n6;ext=value\r\n world\r\n0\r\nx-checksum: abc\r\n\r\n", + None, + None, + ) + .await + .expect("chunked body should decode"); + + assert_eq!(body, b"hello world"); + } + + #[tokio::test] + async fn middleware_chunked_wire_body_at_cap_is_allowed() { + let max_body_bytes = max_middleware_body_bytes().await; + let payload_len = max_body_bytes - 14; + let mut wire = format!("{payload_len:x}\r\n").into_bytes(); + wire.extend(std::iter::repeat_n(b'x', payload_len)); + wire.extend_from_slice(b"\r\n0\r\n\r\n"); + assert_eq!(wire.len(), max_body_bytes); + + let body = collect_chunked_body(&mut tokio::io::empty(), &wire, None, Some(max_body_bytes)) + .await + .expect("wire representation at the cap should be allowed"); + + assert_eq!(body.len(), payload_len); + } + + #[tokio::test] + async fn middleware_chunked_wire_body_over_cap_is_rejected() { + let max_body_bytes = max_middleware_body_bytes().await; + let payload_len = max_body_bytes - 13; + let mut wire = format!("{payload_len:x}\r\n").into_bytes(); + wire.extend(std::iter::repeat_n(b'x', payload_len)); + wire.extend_from_slice(b"\r\n0\r\n\r\n"); + assert_eq!(wire.len(), max_body_bytes + 1); + assert!(payload_len < max_body_bytes); + + let error = + collect_chunked_body(&mut tokio::io::empty(), &wire, None, Some(max_body_bytes)) + .await + .expect_err("wire framing over 
the cap must be rejected"); + + assert!(error.to_string().contains("wire representation")); + } + /// SEC-009: Bare LF in headers enables header injection. #[tokio::test] async fn reject_bare_lf_in_headers() { @@ -5135,6 +5451,38 @@ mod tests { assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); } + #[tokio::test] + async fn relay_request_body_rewrite_normalizes_chunked_payload() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider.v1-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Transfer-Encoding: chunked\r\n\r\n\ + 5\r\nhello\r\n0\r\n\r\n", + ); + + let forwarded = relay_and_capture_with_options( + raw.into_bytes(), + BodyLength::Chunked, + Some(&resolver), + true, + ) + .await + .expect("relay should succeed"); + + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains("Content-Length: 5\r\n")); + assert!(!forwarded.contains("Transfer-Encoding: chunked\r\n")); + assert!(forwarded.ends_with("hello")); + } + #[tokio::test] async fn relay_request_body_rewrites_percent_encoded_canonical_urlencoded_token() { let (_, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index fbab5fedd..9e1427dcd 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -13,9 +13,11 @@ use openshell_core::policy::{ }; use openshell_core::proto::SandboxPolicy as ProtoSandboxPolicy; use openshell_policy::L7ConfigStanza; +use openshell_supervisor_middleware::{ChainEntry, ChainRunner, MiddlewareRegistry}; +use std::collections::HashSet; use std::path::{Path, PathBuf}; use std::sync::{ - Arc, 
Mutex, + Arc, Mutex, RwLock, atomic::{AtomicU64, Ordering}, }; @@ -71,6 +73,7 @@ pub struct SandboxConfig { pub struct OpaEngine { engine: Mutex, generation: Arc, + middleware_runner: RwLock, } /// Generation guard captured when an HTTP tunnel or request path starts. @@ -110,6 +113,7 @@ impl PolicyGenerationGuard { pub struct TunnelPolicyEngine { engine: Mutex, generation_guard: PolicyGenerationGuard, + middleware_runner: ChainRunner, } impl TunnelPolicyEngine { @@ -132,6 +136,19 @@ impl TunnelPolicyEngine { pub(crate) fn engine(&self) -> &Mutex { &self.engine } + + pub(crate) fn middleware_runner(&self) -> &ChainRunner { + &self.middleware_runner + } + + /// Query the ordered middleware chain for a destination within this tunnel. + pub fn query_middleware_chain(&self, input: &NetworkInput) -> Result> { + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + query_middleware_chain_locked(&mut engine, input) + } } impl OpaEngine { @@ -153,6 +170,7 @@ impl OpaEngine { Ok(Self { engine: Mutex::new(engine), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }) } @@ -171,6 +189,7 @@ impl OpaEngine { Ok(Self { engine: Mutex::new(engine), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }) } @@ -193,13 +212,21 @@ impl OpaEngine { /// gap between user-specified symlink paths (e.g., `/usr/bin/python3`) and /// kernel-resolved canonical paths (e.g., `/usr/bin/python3.11`). 
pub fn from_proto_with_pid(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> Result { + if let Err(violations) = openshell_policy::validate_sandbox_policy(proto) { + let errors = violations + .iter() + .map(ToString::to_string) + .collect::>() + .join("\n"); + return Err(miette::miette!("policy validation failed:\n{errors}")); + } + let data_json_str = proto_to_opa_data_json(proto, entrypoint_pid); // Parse back to Value for preprocessing, then re-serialize let mut data: serde_json::Value = serde_json::from_str(&data_json_str) .map_err(|e| miette::miette!("internal: failed to parse proto JSON: {e}"))?; - // Validate BEFORE expanding presets let (errors, warnings) = crate::l7::validate_l7_policies(&data); for w in &warnings { openshell_ocsf::ocsf_emit!( @@ -235,6 +262,7 @@ impl OpaEngine { Ok(Self { engine: Mutex::new(engine), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }) } @@ -432,6 +460,25 @@ impl OpaEngine { self.generation.load(Ordering::Acquire) } + /// Replace the complete middleware service registry and invalidate + /// existing tunnels so subsequent requests use the new service set. + pub fn replace_middleware_registry(&self, registry: MiddlewareRegistry) -> Result<()> { + let mut runner = self + .middleware_runner + .write() + .map_err(|_| miette::miette!("middleware runner lock poisoned"))?; + *runner = ChainRunner::from_registry(registry); + self.generation.fetch_add(1, Ordering::AcqRel); + Ok(()) + } + + pub(crate) fn middleware_runner(&self) -> Result { + self.middleware_runner + .read() + .map(|runner| runner.clone()) + .map_err(|_| miette::miette!("middleware runner lock poisoned")) + } + /// Return a guard for a previously captured policy generation. pub fn generation_guard(&self, expected_generation: u64) -> Result { let generation = self.current_generation(); @@ -548,6 +595,20 @@ impl OpaEngine { } } + /// Query the ordered middleware chain for an admitted destination. 
+ pub fn query_middleware_chain_with_generation( + &self, + input: &NetworkInput, + ) -> Result<(Vec, u64)> { + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + let generation = self.current_generation(); + let chain = query_middleware_chain_locked(&mut engine, input)?; + Ok((chain, generation)) + } + /// Query `allowed_ips` from the matched endpoint config for a given request. /// /// Returns the list of CIDR/IP strings from the endpoint's `allowed_ips` @@ -629,6 +690,7 @@ impl OpaEngine { captured_generation: generation, current_generation: Arc::clone(&self.generation), }, + middleware_runner: self.middleware_runner()?, }) } } @@ -687,6 +749,157 @@ fn get_str_array(val: ®orus::Value, key: &str) -> Vec { } } +fn network_input_json(input: &NetworkInput) -> serde_json::Value { + let ancestor_strs: Vec = input + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + let cmdline_strs: Vec = input + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + serde_json::json!({ + "exec": { + "path": input.binary_path.to_string_lossy(), + "ancestors": ancestor_strs, + "cmdline_paths": cmdline_strs, + }, + "network": { + "host": input.host, + "port": input.port, + } + }) +} + +fn query_middleware_chain_locked( + engine: &mut regorus::Engine, + input: &NetworkInput, +) -> Result> { + engine + .set_input_json(&network_input_json(input).to_string()) + .map_err(|e| miette::miette!("{e}"))?; + + let configs_val = engine + .eval_rule("data.openshell.sandbox.network_middlewares".into()) + .map_err(|e| miette::miette!("{e}"))?; + let configs = parse_middleware_configs(&configs_val)?; + if configs.is_empty() { + return Ok(Vec::new()); + } + global_middleware_entries(&configs, &input.host) +} + +fn parse_middleware_configs(value: ®orus::Value) -> Result> { + match value { + regorus::Value::Undefined => Ok(Vec::new()), + regorus::Value::Array(values) => Ok(values.to_vec()), 
+ other => Err(miette::miette!( + "network_middlewares must be an array, got {other:?}" + )), + } +} + +fn global_middleware_entries(configs: &[regorus::Value], host: &str) -> Result> { + let mut entries = Vec::new(); + for config in configs { + if middleware_selector_matches(config, host)? { + entries.push(chain_entry_from_value(config)?); + } + } + Ok(entries) +} + +fn middleware_selector_matches(config: ®orus::Value, host: &str) -> Result { + let Some(selector) = get_field(config, "endpoints") else { + return Ok(false); + }; + let include_patterns = get_str_array(selector, "include"); + let exclude_patterns = get_str_array(selector, "exclude"); + let matches_include = include_patterns + .iter() + .try_fold(false, |matched, pattern| { + openshell_policy::middleware_host_matches(pattern, host) + .map(|matches| matched || matches) + .map_err(|error| miette::miette!(error)) + })?; + let matches_exclude = exclude_patterns + .iter() + .try_fold(false, |matched, pattern| { + openshell_policy::middleware_host_matches(pattern, host) + .map(|matches| matched || matches) + .map_err(|error| miette::miette!(error)) + })?; + Ok(matches_include && !matches_exclude) +} + +fn chain_entry_from_value(value: ®orus::Value) -> Result { + let name = get_str(value, "name").unwrap_or_default(); + let implementation = get_str(value, "middleware").unwrap_or_default(); + Ok(ChainEntry { + name, + implementation, + config: get_field(value, "config") + .map(regorus_value_to_struct) + .unwrap_or_default(), + on_error: openshell_supervisor_middleware::OnError::parse( + get_str(value, "on_error").as_deref().unwrap_or_default(), + )?, + }) +} + +fn get_field<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Value> { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => map.get(&key_val), + _ => None, + } +} + +fn regorus_value_to_struct(value: ®orus::Value) -> prost_types::Struct { + let regorus::Value::Object(map) = value else { + return 
prost_types::Struct::default(); + }; + prost_types::Struct { + fields: map + .iter() + .filter_map(|(key, value)| match key { + regorus::Value::String(key) => { + Some((key.to_string(), regorus_value_to_prost(value))) + } + _ => None, + }) + .collect(), + } +} + +fn regorus_value_to_prost(value: ®orus::Value) -> prost_types::Value { + use prost_types::{ListValue, Struct, Value, value::Kind}; + Value { + kind: Some(match value { + regorus::Value::Bool(value) => Kind::BoolValue(*value), + regorus::Value::Number(value) => Kind::NumberValue(value.as_f64().unwrap_or_default()), + regorus::Value::String(value) => Kind::StringValue(value.to_string()), + regorus::Value::Array(values) => Kind::ListValue(ListValue { + values: values.iter().map(regorus_value_to_prost).collect(), + }), + regorus::Value::Object(values) => Kind::StructValue(Struct { + fields: values + .iter() + .filter_map(|(key, value)| match key { + regorus::Value::String(key) => { + Some((key.to_string(), regorus_value_to_prost(value))) + } + _ => None, + }) + .collect(), + }), + _ => Kind::NullValue(0), + }), + } +} + fn parse_filesystem_policy(val: ®orus::Value) -> FilesystemPolicy { FilesystemPolicy { read_only: get_str_array(val, "read_only") @@ -735,6 +948,14 @@ fn preprocess_yaml_data(yaml_str: &str) -> Result { } // Validate BEFORE expanding presets (catches user errors like rules+access) + let middleware_errors = validate_middleware_policies(&data); + if !middleware_errors.is_empty() { + return Err(miette::miette!( + "middleware policy validation failed:\n{}", + middleware_errors.join("\n") + )); + } + let (errors, warnings) = crate::l7::validate_l7_policies(&data); for w in &warnings { openshell_ocsf::ocsf_emit!( @@ -955,6 +1176,137 @@ fn normalize_l7_rule_aliases( } } +fn validate_middleware_policies(data: &serde_json::Value) -> Vec { + let mut errors = Vec::new(); + let middlewares = data + .get("network_middlewares") + .and_then(serde_json::Value::as_array) + .map_or(&[][..], Vec::as_slice); + let 
mut names = HashSet::new(); + for mw in middlewares { + let name = mw + .get("name") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + let implementation = mw + .get("middleware") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + if name.is_empty() { + errors.push("network_middlewares entry has empty name".to_string()); + } else if !names.insert(name.to_string()) { + errors.push(format!("duplicate middleware config '{name}'")); + } + if implementation.is_empty() { + errors.push(format!( + "middleware config '{name}' has empty implementation" + )); + } + if implementation.starts_with("openshell/") + && implementation != openshell_supervisor_middleware::BUILTIN_SECRETS + { + errors.push(format!( + "middleware config '{name}' references unsupported built-in '{implementation}'" + )); + } + let on_error = mw + .get("on_error") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + if !matches!(on_error, "" | "fail_closed" | "fail_open") { + errors.push(format!( + "middleware config '{name}' has invalid on_error '{on_error}'" + )); + } + + let Some(selector) = mw.get("endpoints") else { + errors.push(format!( + "middleware config '{name}' requires an endpoint selector" + )); + continue; + }; + let includes = json_string_array(selector.get("include")); + let excludes = json_string_array(selector.get("exclude")); + if includes.is_empty() { + errors.push(format!( + "middleware config '{name}' endpoint selector must include at least one host pattern" + )); + } + for pattern in includes.iter().chain(&excludes) { + if let Err(error) = + openshell_policy::middleware_host_matches(pattern, "validation.invalid") + { + errors.push(format!( + "middleware config '{name}' has invalid endpoint selector pattern '{pattern}': {error}" + )); + } + } + } + + let Some(policies) = data + .get("network_policies") + .and_then(serde_json::Value::as_object) + else { + return errors; + }; + + for (policy_name, policy) in policies { + for endpoint in policy + 
.get("endpoints") + .and_then(serde_json::Value::as_array) + .map_or(&[][..], Vec::as_slice) + { + let tls_skip = endpoint + .get("tls") + .and_then(serde_json::Value::as_str) + .is_some_and(|tls| tls == "skip"); + if tls_skip && global_selector_matches_any_middleware(middlewares, endpoint) { + errors.push(format!( + "network policy '{policy_name}' tls: skip endpoint matches a global middleware selector" + )); + } + } + } + errors +} + +fn json_string_array(value: Option<&serde_json::Value>) -> Vec { + value + .and_then(serde_json::Value::as_array) + .map(|values| { + values + .iter() + .filter_map(serde_json::Value::as_str) + .map(ToString::to_string) + .collect() + }) + .unwrap_or_default() +} + +fn global_selector_matches_any_middleware( + middlewares: &[serde_json::Value], + endpoint: &serde_json::Value, +) -> bool { + let host = endpoint + .get("host") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + middlewares.iter().any(|mw| { + let Some(selector) = mw.get("endpoints") else { + return false; + }; + let includes = json_string_array(selector.get("include")); + let excludes = json_string_array(selector.get("exclude")); + !includes.is_empty() + && includes.iter().any(|pattern| { + openshell_policy::middleware_host_matches(pattern, host).unwrap_or(false) + }) + && !excludes.iter().any(|pattern| { + openshell_policy::middleware_host_matches(pattern, host).unwrap_or(false) + }) + }) +} + /// Resolve a policy binary path through the container's root filesystem. 
/// /// On Linux, `/proc//root/` provides access to the container's mount @@ -1341,14 +1693,40 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St entries }) .collect(); - ( - key.clone(), - serde_json::json!({ - "name": rule.name, - "endpoints": endpoints, - "binaries": binaries, - }), - ) + let policy = serde_json::json!({ + "name": rule.name, + "endpoints": endpoints, + "binaries": binaries, + }); + (key.clone(), policy) + }) + .collect(); + + let network_middlewares: Vec = proto + .network_middlewares + .iter() + .map(|mw| { + let mut value = serde_json::json!({ + "name": mw.name, + "middleware": mw.middleware, + }); + if let Some(config) = &mw.config { + value["config"] = prost_struct_to_json(config); + } + if !mw.on_error.is_empty() { + value["on_error"] = mw.on_error.clone().into(); + } + if let Some(selector) = &mw.endpoints { + let mut endpoints = serde_json::json!({}); + if !selector.include.is_empty() { + endpoints["include"] = selector.include.clone().into(); + } + if !selector.exclude.is_empty() { + endpoints["exclude"] = selector.exclude.clone().into(); + } + value["endpoints"] = endpoints; + } + value }) .collect(); @@ -1357,10 +1735,37 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St "landlock": landlock, "process": process, "network_policies": network_policies, + "network_middlewares": network_middlewares, }) .to_string() } +fn prost_struct_to_json(config: &prost_types::Struct) -> serde_json::Value { + serde_json::Value::Object( + config + .fields + .iter() + .map(|(key, value)| (key.clone(), prost_value_to_json(value))) + .collect(), + ) +} + +fn prost_value_to_json(value: &prost_types::Value) -> serde_json::Value { + match value.kind.as_ref() { + Some(prost_types::value::Kind::NullValue(_)) | None => serde_json::Value::Null, + Some(prost_types::value::Kind::BoolValue(value)) => serde_json::Value::Bool(*value), + Some(prost_types::value::Kind::NumberValue(value)) => 
serde_json::Number::from_f64(*value) + .map_or(serde_json::Value::Null, serde_json::Value::Number), + Some(prost_types::value::Kind::StringValue(value)) => { + serde_json::Value::String(value.clone()) + } + Some(prost_types::value::Kind::ListValue(value)) => { + serde_json::Value::Array(value.values.iter().map(prost_value_to_json).collect()) + } + Some(prost_types::value::Kind::StructValue(value)) => prost_struct_to_json(value), + } +} + #[cfg(test)] #[allow( clippy::needless_raw_string_hashes, @@ -1439,6 +1844,7 @@ mod tests { run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], } } @@ -2564,6 +2970,7 @@ network_policies: let engine = OpaEngine { engine: Mutex::new(rego), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }; let input = l7_websocket_graphql_input( "realtime.graphql.com", @@ -2781,6 +3188,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -2852,6 +3260,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -2924,6 +3333,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3800,6 +4210,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3857,6 +4268,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3915,6 +4327,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), 
network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3975,6 +4388,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -4034,6 +4448,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -4983,6 +5398,7 @@ process: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); let input = NetworkInput { @@ -5037,6 +5453,7 @@ process: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); let input = NetworkInput { @@ -5107,6 +5524,7 @@ process: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); @@ -5337,6 +5755,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).unwrap(); // Port 443 @@ -6279,6 +6698,7 @@ network_policies: path: link_path, ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -6296,6 +6716,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; // Build engine with our PID (symlink resolution will work via /proc/self/root/) @@ -6356,6 +6777,7 @@ network_policies: path: link_path, ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -6373,6 +6795,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + 
network_middlewares: vec![], }; // Initial load at pid=0 — no symlink expansion @@ -6415,6 +6838,171 @@ network_policies: assert!(eval_l7(&engine, &input)); } + #[test] + fn middleware_chain_uses_matching_selector_declaration_order() { + let data = r#" +network_middlewares: + - name: global-redactor + middleware: openshell/secrets + endpoints: + include: ["api.example.com"] + - name: policy-redactor + middleware: openshell/secrets + endpoints: + include: ["api.example.com"] + - name: endpoint-redactor + middleware: openshell/secrets + endpoints: + include: ["api.example.com"] +network_policies: + api: + name: api + endpoints: + - host: api.example.com + port: 443 + protocol: rest + enforcement: enforce + rules: + - allow: { method: POST, path: "/v1/**" } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (chain, _) = engine + .query_middleware_chain_with_generation(&input) + .unwrap(); + let names: Vec<_> = chain.iter().map(|entry| entry.name.as_str()).collect(); + assert_eq!( + names, + vec!["global-redactor", "policy-redactor", "endpoint-redactor"] + ); + } + + #[test] + fn middleware_policy_validation_rejects_bad_configs() { + let cases = [ + ( + "invalid on_error", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets + on_error: maybe + endpoints: + include: ["api.example.com"] +"#, + "invalid on_error", + ), + ( + "duplicate names", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets + endpoints: + include: ["api.example.com"] + - name: redactor + middleware: openshell/secrets + endpoints: + include: ["api.example.com"] +"#, + "duplicate middleware config 'redactor'", + ), + ( + "reserved builtin", + r#" +network_middlewares: + - name: 
sigv4 + middleware: openshell/sigv4 + endpoints: + include: ["api.example.com"] +"#, + "unsupported built-in", + ), + ( + "missing selector", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets +"#, + "requires an endpoint selector", + ), + ( + "malformed selector", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets + endpoints: + include: ["api[.example.com"] +"#, + "invalid host pattern", + ), + ( + "tls skip selector", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets + endpoints: + include: ["api.example.com"] +network_policies: + api: + endpoints: + - host: api.example.com + port: 443 + tls: skip + binaries: + - { path: /usr/bin/curl } +"#, + "tls: skip", + ), + ]; + + for (name, data, expected) in cases { + let err = match OpaEngine::from_strings(TEST_POLICY, data) { + Ok(_) => panic!("{name}: expected policy validation failure"), + Err(err) => err.to_string(), + }; + assert!( + err.contains(expected), + "{name}: expected {expected:?} in {err:?}" + ); + } + } + + #[test] + fn from_proto_revalidates_middleware_policy() { + let mut policy = openshell_policy::restrictive_default_policy(); + policy + .network_middlewares + .push(openshell_core::proto::NetworkMiddlewareConfig { + name: "redactor".into(), + middleware: "openshell/secrets".into(), + endpoints: Some(openshell_core::proto::MiddlewareEndpointSelector { + include: vec!["api[.example.com".into()], + exclude: Vec::new(), + }), + ..Default::default() + }); + + let error = OpaEngine::from_proto(&policy) + .err() + .expect("supervisor must reject invalid effective middleware policy") + .to_string(); + assert!(error.contains("policy validation failed"), "{error}"); + assert!(error.contains("invalid host pattern"), "{error}"); + } + #[test] fn l7_head_denied_when_only_post_allowed() { let engine = OpaEngine::from_strings( diff --git a/crates/openshell-supervisor-network/src/proxy.rs 
b/crates/openshell-supervisor-network/src/proxy.rs index 0d2c8c025..8616c3b2c 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -1183,6 +1183,7 @@ async fn handle_tcp_connection( &mut tls_upstream, &ctx, &generation_guard, + Some(&opa_engine), ) .await } @@ -1288,6 +1289,7 @@ async fn handle_tcp_connection( &mut upstream, &ctx, &generation_guard, + Some(&opa_engine), ) .await { @@ -4182,6 +4184,76 @@ async fn handle_forward_proxy( } emit_forward_success_activity(activity_tx, l7_activity_pending); + let middleware_path = path.split_once('?').map_or(path.as_str(), |(path, _)| path); + let middleware_input = crate::opa::NetworkInput { + host: host_lc.clone(), + port, + binary_path: decision.binary.clone().unwrap_or_default(), + binary_sha256: String::new(), + ancestors: decision.ancestors.clone(), + cmdline_paths: decision.cmdline_paths.clone(), + }; + let (chain, generation) = + opa_engine.query_middleware_chain_with_generation(&middleware_input)?; + if generation != forward_generation_guard.captured_generation() { + emit_l7_tunnel_close_after_policy_change( + &host_lc, + port, + miette::miette!( + "policy changed before forward middleware evaluation [expected_generation:{} current_generation:{}]", + forward_generation_guard.captured_generation(), + generation, + ), + ); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("{method} {host_lc}:{port}{path} not permitted by policy"), + ), + ) + .await?; + return Ok(()); + } + if !chain.is_empty() { + let middleware_runner = opa_engine.middleware_runner()?; + let request = crate::l7::rest::request_from_buffered_http( + method, + middleware_path, + &upstream_target, + forward_request_bytes, + )?; + forward_request_bytes = match crate::l7::relay::apply_middleware_chain_for_scheme( + request, + client, + &l7_ctx, + &scheme, + chain, + &middleware_runner, + &forward_generation_guard, + ) + .await? 
+ { + crate::l7::relay::MiddlewareApplyResult::Allowed(request) => request.raw_header, + crate::l7::relay::MiddlewareApplyResult::Denied(reason) => { + emit_activity_simple(activity_tx, true, "middleware"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "middleware_denied", + &format!("{method} {host_lc}:{port}{path} denied by middleware: {reason}"), + ), + ) + .await?; + return Ok(()); + } + }; + } + forward_request_bytes = match inject_token_grant_for_forward_request( method, &upstream_target, diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index 2aaa6e7b0..e9d2e579e 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -103,6 +103,14 @@ guest_tls_key = "/etc/openshell/certs/client-key.pem" grpc_rate_limit_requests = 120 grpc_rate_limit_window_seconds = 60 +# Local-development-only external supervisor middleware. The endpoint must be +# reachable from both the gateway and sandbox supervisors. +[[openshell.gateway.middleware]] +name = "local-content-guard" +endpoint = "http://host.openshell.internal:50051" +allow_insecure = true +max_body_bytes = 262144 + # Gateway listener TLS (distinct from the per-driver guest_tls_*). [openshell.gateway.tls] cert_path = "/etc/openshell/certs/gateway.pem" @@ -140,6 +148,30 @@ Local Docker, Podman, and VM gateways can also set `[openshell.gateway.mtls_auth `[openshell.gateway.auth] allow_unauthenticated_users = true` is an unsafe local-development and trusted-proxy escape hatch. It accepts user-facing CLI/API calls without OIDC or mTLS credentials while sandbox supervisors still authenticate with gateway-minted sandbox JWTs. Leave it false for shared and production gateways. +## Supervisor Middleware Services + + +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. 
+ + +Register operator-run supervisor middleware services with one or more `[[openshell.gateway.middleware]]` entries. Registration is static and operator-owned; changing it requires restarting the gateway. + +```toml +[[openshell.gateway.middleware]] +name = "local-content-guard" +endpoint = "http://host.openshell.internal:50051" +allow_insecure = true +max_body_bytes = 262144 +``` + +Each service implements the supervisor middleware gRPC contract and may expose multiple binding IDs through `Describe`. Policies reference those binding IDs, not the registration `name`. The gateway rejects duplicate binding IDs across services and prevents operator-run services from claiming the reserved `openshell/` namespace. + +The gateway connects to every registered service and validates `Describe` before it starts. The service must therefore be running before the gateway. Policy creation and full policy updates call `ValidateConfig`; an unavailable service or invalid middleware configuration rejects the policy before persistence. + +`max_body_bytes` is the operator limit for every binding exposed by the service. It must be greater than zero and no larger than each binding's advertised limit. OpenShell rejects an oversized value instead of silently clamping it. + +The service endpoint must use plaintext `http://`, and `allow_insecure = true` is required as an explicit acknowledgement that inspected request content is sent without transport encryption or peer authentication. TLS, authentication, health checks, and runtime registration are not supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address when both runtimes resolve it. + `image_pull_policy` is intentionally not a shared gateway key. Kubernetes and Docker use `Always`, `IfNotPresent`, or `Never`. Podman uses `always`, `missing`, `never`, or `newer`. Set it inside the relevant driver table. 
## Driver References diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 906d55d58..1d24b28df 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -20,6 +20,7 @@ filesystem_policy: { ... } landlock: { ... } process: { ... } network_policies: { ... } +network_middlewares: [ ... ] ``` | Field | Type | Required | Category | Description | @@ -29,6 +30,7 @@ network_policies: { ... } | `landlock` | object | No | Static | Configures Landlock LSM enforcement behavior. | | `process` | object | No | Static | Sets the user and group the agent process runs as. | | `network_policies` | map | No | Dynamic | Declares which binaries can reach which network endpoints. | +| `network_middlewares` | list | No | Dynamic | Selects ordered HTTP request middleware by destination host. | Static fields are set at sandbox creation time. Changing them requires destroying and recreating the sandbox. Dynamic fields can be updated on a running sandbox with `openshell policy update` for incremental merges or `openshell policy set` for full replacement, and take effect without restarting. @@ -468,7 +470,39 @@ Identifies an executable that is permitted to use the associated endpoints. |---|---|---|---| | `path` | string | Yes | Filesystem path to the executable. Supports glob patterns with `*` and `**`. For example, `/sandbox/.vscode-server/**` matches any executable under that directory tree. | -### Full Example +## Network Middleware + + +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. + + +**Category:** Dynamic + +An ordered list of middleware configs selected after network and L7 policy admit an HTTP request. Middleware selection is independent of the network policy entry that admitted the request. Every matching config runs once in list order before provider credential injection. 
+ +```yaml showLineNumbers={false} +network_middlewares: + - name: redact-secrets + middleware: openshell/secrets + config: + secrets: redact + on_error: fail_closed + endpoints: + include: ["*.example.com"] + exclude: ["trusted.example.com"] +``` + +| Field | Type | Required | Description | +|---|---|---|---| +| `name` | string | Yes | Policy-local config name. Names must be unique within the list. | +| `middleware` | string | Yes | Built-in or operator-registered binding ID. `openshell/` is reserved for built-ins. | +| `config` | object | No | Implementation-owned configuration validated by the selected middleware. | +| `on_error` | string | No | `fail_closed` denies the request when the stage fails; `fail_open` skips the failed stage. Defaults to `fail_closed`. | +| `endpoints` | object | Yes | Host selector with required non-empty `include` and optional `exclude` lists. Exclusions take precedence. | + +Host selectors use the same case-insensitive exact and DNS glob semantics as network endpoints. Middleware runs only on HTTP requests the supervisor parses. A selector that can require middleware on a `tls: skip` endpoint is rejected because OpenShell cannot inspect that traffic. + +## Full Example The following policy grants read-only GitHub API access and npm registry access: diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index ea8716422..ac089ac94 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -12,7 +12,7 @@ Use this page to apply and iterate policy changes on running sandboxes. For a fu ## Policy Structure -A policy has static sections `filesystem_policy`, `landlock`, and `process` that are locked at sandbox creation, and a dynamic section `network_policies` that is hot-reloadable on a running sandbox. 
+A policy has static sections `filesystem_policy`, `landlock`, and `process` that are locked at sandbox creation, and dynamic `network_policies` and `network_middlewares` sections that are hot-reloadable on a running sandbox. ```yaml wordWrap showLineNumbers={false} version: 1 @@ -44,6 +44,17 @@ network_policies: binaries: - path: /usr/bin/curl +# Dynamic: ordered middleware selected independently by admitted host. +network_middlewares: + - name: redact-secrets + middleware: openshell/secrets + config: + secrets: redact + on_error: fail_closed + endpoints: + include: ["api.example.com"] + exclude: [] + ``` Static sections are locked at sandbox creation. Changing them requires destroying and recreating the sandbox. @@ -57,6 +68,33 @@ Raw streams are connection-scoped and outside L7 live-reload guarantees. This in | `landlock` | Static | Configures Landlock LSM enforcement behavior. Set `compatibility` to `best_effort` (skip individual inaccessible paths while applying remaining rules) or `hard_requirement` (fail if any path is inaccessible or the required kernel ABI is unavailable). Refer to the [Policy Schema Reference](/reference/policy-schema#landlock) for the full behavior table. | | `process` | Static | Sets the OS-level identity for the agent process. `run_as_user` and `run_as_group` default to `sandbox`. Root (`root` or `0`) is rejected. The agent also runs with seccomp filters that block dangerous system calls. | | `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path). For endpoints with `protocol: websocket`, the proxy validates the RFC 6455 upgrade and evaluates `GET` rules for the handshake plus either `WEBSOCKET_TEXT` rules for raw client text messages or GraphQL operation rules for GraphQL-over-WebSocket messages. Set `websocket_credential_rewrite: true` only when a WebSocket or REST compatibility endpoint must keep placeholder credentials in sandbox-owned text frames and resolve them at the OpenShell relay boundary.
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | +| `network_middlewares` | Dynamic | Declares ordered HTTP request middleware configs. After network and L7 policy admit a request, OpenShell matches each config's host selectors independently and runs matching entries in declaration order before credential injection. | + +## Supervisor Middleware + + +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. + + +Supervisor middleware can inspect, deny, or replace admitted HTTP request bodies before provider credentials are injected. Middleware selection is independent of the `network_policies` rule that admitted the request: each `network_middlewares` entry matches the destination host through `endpoints.include` and `endpoints.exclude`. + +```yaml +network_middlewares: + - name: redact-secrets + middleware: openshell/secrets + config: + secrets: redact + on_error: fail_closed + endpoints: + include: ["*.example.com"] + exclude: ["trusted.example.com"] +``` + +Matching entries run once each in top-level declaration order. Config names must be unique. Different config names may use the same implementation and run as distinct stages. `exclude` takes precedence over `include`. + +`openshell/secrets` is built into the supervisor. Operator-provided binding IDs must be registered before a policy can reference them; see [Supervisor Middleware Services](/reference/gateway-config#supervisor-middleware-services). The gateway calls the implementation's `ValidateConfig` before accepting the policy. + +`on_error` defaults to `fail_closed`. Use `fail_open` only when skipping a failed middleware is acceptable. 
Middleware applies only to HTTP traffic the supervisor can parse and inspect; policy validation rejects a required selector that can cover a `tls: skip` endpoint. ## Baseline Filesystem Paths diff --git a/proto/middleware.proto b/proto/middleware.proto new file mode 100644 index 000000000..2944227d8 --- /dev/null +++ b/proto/middleware.proto @@ -0,0 +1,97 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package openshell.middleware.v1; + +import "google/protobuf/empty.proto"; +import "google/protobuf/struct.proto"; + +service SupervisorMiddleware { + rpc Describe(google.protobuf.Empty) returns (MiddlewareManifest); + rpc ValidateConfig(ValidateConfigRequest) returns (ValidateConfigResponse); + rpc EvaluateHttpRequest(HttpRequestEvaluation) returns (HttpRequestResult); +} + +message MiddlewareManifest { + string api_version = 1; + string name = 2; + string service_version = 3; + repeated MiddlewareBinding bindings = 4; +} + +message MiddlewareBinding { + string id = 1; + string operation = 2; + string phase = 3; + // Maximum request or replacement body this binding can process. 
+ uint64 max_body_bytes = 4; +} + +message ValidateConfigRequest { + string api_version = 1; + string binding_id = 2; + google.protobuf.Struct config = 3; +} + +message ValidateConfigResponse { + bool valid = 1; + string reason = 2; +} + +message HttpRequestEvaluation { + string api_version = 1; + string binding_id = 2; + string phase = 3; + RequestContext context = 4; + google.protobuf.Struct config = 5; + HttpRequestTarget target = 6; + map<string, string> headers = 7; + bytes body = 8; +} + +message RequestContext { + string request_id = 1; + string sandbox_id = 2; + Process originating_process = 3; +} + +message HttpRequestTarget { + string scheme = 1; + string host = 2; + uint32 port = 3; + string method = 4; + string path = 5; + string query = 6; +} + +message Process { + string binary = 1; + uint32 pid = 2; + repeated string ancestors = 3; +} + +enum Decision { + DECISION_UNSPECIFIED = 0; + DECISION_ALLOW = 1; + DECISION_DENY = 2; +} + +message Finding { + string type = 1; + string label = 2; + uint32 count = 3; + string confidence = 4; + string severity = 5; +} + +message HttpRequestResult { + Decision decision = 1; + string reason = 2; + bytes body = 3; + bool has_body = 4; + map<string, string> add_headers = 5; + repeated Finding findings = 6; + map<string, string> metadata = 7; +} diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 8a5a59333..644fd86cb 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -5,6 +5,8 @@ syntax = "proto3"; package openshell.sandbox.v1; +import "google/protobuf/struct.proto"; + // Sandbox-supervisor configuration and policy messages. // // Conventions: @@ -25,6 +27,8 @@ message SandboxPolicy { ProcessPolicy process = 4; // Network access policies keyed by name (e.g. "claude_code", "gitlab"). map<string, NetworkPolicyRule> network_policies = 5; + // Reusable supervisor middleware configs for network egress. + repeated NetworkMiddlewareConfig network_middlewares = 6; } // Filesystem access policy. 
@@ -61,6 +65,25 @@ message NetworkPolicyRule { repeated NetworkBinary binaries = 3; } +// A reusable middleware config selected for admitted egress by host. +message NetworkMiddlewareConfig { + // Policy-local config name. + string name = 1; + // Built-in or operator-registered middleware binding ID (not the service registration name). + string middleware = 2; + // Implementation-owned configuration, validated by the selected middleware. + google.protobuf.Struct config = 3; + // Failure behavior: "fail_closed" (default) or "fail_open". + string on_error = 4; + // Host selector controlling which admitted destinations use this config. + MiddlewareEndpointSelector endpoints = 5; +} + +message MiddlewareEndpointSelector { + repeated string include = 1; + repeated string exclude = 2; +} + // A network endpoint (host + port) with optional L7 inspection config. message NetworkEndpoint { // Hostname or host glob pattern. Exact match is case-insensitive. @@ -329,4 +352,20 @@ message GetSandboxConfigResponse { // Fingerprint for provider credential inputs attached to this sandbox. // Changes when attached provider names or attached provider records change. uint64 provider_env_revision = 8; + // Operator-registered supervisor middleware services required by the + // effective policy. Built-in middleware is not included. + repeated SupervisorMiddlewareService supervisor_middleware_services = 9; +} + +// Connection details for one operator-registered supervisor middleware service. +// V1 supports only explicitly enabled plaintext gRPC for local development. +message SupervisorMiddlewareService { + // Operator-facing registration name used for diagnostics. + string name = 1; + // gRPC endpoint reachable from the sandbox supervisor. + string endpoint = 2; + // Explicit acknowledgement that request content is sent without TLS. + bool allow_insecure = 3; + // Operator-owned body limit applied to every binding exposed by the service. + uint64 max_body_bytes = 4; }