From 6ef949069588d8b9677cfa15dfe4b5a82b4d7779 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Thu, 25 Jun 2026 17:05:52 -0700 Subject: [PATCH 01/27] feat(middleware): add in-process supervisor middleware Signed-off-by: Piotr Mlocek --- Cargo.lock | 15 + Cargo.toml | 1 + crates/openshell-core/src/proto/mod.rs | 14 + crates/openshell-policy/Cargo.toml | 1 + crates/openshell-policy/src/compose.rs | 1 + crates/openshell-policy/src/lib.rs | 220 +++++- crates/openshell-policy/src/merge.rs | 24 + crates/openshell-providers/src/profiles.rs | 2 + .../Cargo.toml | 25 + .../src/builtins/mod.rs | 4 + .../src/builtins/secrets.rs | 83 +++ .../src/lib.rs | 437 ++++++++++++ .../src/service.rs | 75 ++ .../openshell-supervisor-network/Cargo.toml | 2 + .../data/sandbox-policy.rego | 22 +- .../src/l7/relay.rs | 478 ++++++++++++- .../src/l7/rest.rs | 122 ++++ .../openshell-supervisor-network/src/opa.rs | 638 +++++++++++++++++- .../src/policy_local.rs | 5 + .../openshell-supervisor-network/src/proxy.rs | 2 + proto/middleware.proto | 95 +++ proto/sandbox.proto | 29 +- 22 files changed, 2278 insertions(+), 17 deletions(-) create mode 100644 crates/openshell-supervisor-middleware/Cargo.toml create mode 100644 crates/openshell-supervisor-middleware/src/builtins/mod.rs create mode 100644 crates/openshell-supervisor-middleware/src/builtins/secrets.rs create mode 100644 crates/openshell-supervisor-middleware/src/lib.rs create mode 100644 crates/openshell-supervisor-middleware/src/service.rs create mode 100644 proto/middleware.proto diff --git a/Cargo.lock b/Cargo.lock index 13b670f55..6f664acba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3771,6 +3771,7 @@ version = "0.0.0" dependencies = [ "miette", "openshell-core", + "prost-types", "serde", "serde_json", "serde_yml", @@ -3926,6 +3927,18 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "openshell-supervisor-middleware" +version = "0.0.0" +dependencies = [ + "miette", + "openshell-core", + "prost-types", + "regex", + "tokio", + "tonic", +] + [[package]] name = "openshell-supervisor-network" version = "0.0.0" @@ -3948,6 +3961,8 @@ dependencies = [ "openshell-ocsf", "openshell-policy", "openshell-router", + "openshell-supervisor-middleware", + "prost-types", "rcgen", "regorus", "reqwest 0.12.28", diff --git a/Cargo.toml b/Cargo.toml index f450cd5c8..fd3641d68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ serde_yml = "0.0.12" toml = "0.8" apollo-parser = "0.8.5" tower-mcp-types = "0.12.0" +regex = "1" # HTTP client reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls-native-roots"] } diff --git a/crates/openshell-core/src/proto/mod.rs b/crates/openshell-core/src/proto/mod.rs index 08b062d2e..1ac6fc94c 100644 --- a/crates/openshell-core/src/proto/mod.rs +++ b/crates/openshell-core/src/proto/mod.rs @@ -79,8 +79,22 @@ pub mod inference { } } +#[allow( + clippy::all, + clippy::pedantic, + clippy::nursery, + unused_qualifications, + rust_2018_idioms +)] +pub mod middleware { + pub mod v1 { + include!(concat!(env!("OUT_DIR"), "/openshell.middleware.v1.rs")); + } +} + pub use datamodel::v1::*; pub use inference::v1::*; +pub use middleware::v1::*; pub use openshell::*; pub use sandbox::v1::*; pub use test::ObjectForTest; diff --git a/crates/openshell-policy/Cargo.toml b/crates/openshell-policy/Cargo.toml index 16719de13..50bea5b32 100644 --- a/crates/openshell-policy/Cargo.toml +++ b/crates/openshell-policy/Cargo.toml @@ -12,6 +12,7 @@ repository.workspace = true [dependencies] openshell-core = { path = "../openshell-core", default-features = false } +prost-types = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } serde_yml = { workspace = true } diff --git a/crates/openshell-policy/src/compose.rs b/crates/openshell-policy/src/compose.rs index 7ca8584d9..1ad0d4617 100644 --- a/crates/openshell-policy/src/compose.rs +++ b/crates/openshell-policy/src/compose.rs @@ -115,6 +115,7 @@ mod tests { ..Default::default() }], binaries: Vec::new(), + middleware: Vec::new(), } } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index f1721146e..0aa43c30d 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -19,8 +19,8 @@ use std::path::Path; use miette::{IntoDiagnostic, Result, WrapErr}; use openshell_core::proto::{ FilesystemPolicy, GraphqlOperation, L7Allow, L7DenyRule, L7QueryMatcher, L7Rule, - LandlockPolicy, McpOptions, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, ProcessPolicy, - SandboxPolicy, + LandlockPolicy, MiddlewareEndpointSelector, NetworkBinary, NetworkEndpoint, + NetworkMiddlewareConfig, NetworkPolicyRule, ProcessPolicy, SandboxPolicy, McpOptions, }; use serde::{Deserialize, Serialize}; @@ -49,6 +49,8 @@ struct PolicyFile { process: Option, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] network_policies: BTreeMap, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + network_middlewares: Vec, } #[derive(Debug, Serialize, Deserialize)] @@ -87,6 +89,30 @@ struct NetworkPolicyRuleDef { endpoints: Vec, #[serde(default, skip_serializing_if = "Vec::is_empty")] binaries: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + middleware: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +struct NetworkMiddlewareConfigDef { + name: String, + middleware: String, + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + config: BTreeMap, + #[serde(default, skip_serializing_if = "String::is_empty")] + on_error: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + endpoints: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +struct MiddlewareEndpointSelectorDef { + #[serde(default, skip_serializing_if = "Vec::is_empty")] + include: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + exclude: Vec, } #[derive(Debug, Serialize, Deserialize)] @@ -148,6 +174,8 @@ struct NetworkEndpointDef { json_rpc: Option, #[serde(default, skip_serializing_if = "Option::is_none")] mcp: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + middleware: Vec, } // Signature dictated by serde's `skip_serializing_if`, which requires `&T`. @@ -672,6 +700,21 @@ fn yaml_mcp_method( } fn to_proto(raw: PolicyFile) -> SandboxPolicy { + let network_middlewares = raw + .network_middlewares + .into_iter() + .map(|mw| NetworkMiddlewareConfig { + name: mw.name, + middleware: mw.middleware, + config: Some(json_map_to_struct(mw.config)), + on_error: mw.on_error, + endpoints: mw.endpoints.map(|selector| MiddlewareEndpointSelector { + include: selector.include, + exclude: selector.exclude, + }), + }) + .collect(); + let network_policies = raw .network_policies .into_iter() @@ -745,6 +788,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { signing_region: e.signing_region, json_rpc_max_body_bytes: json_rpc_max_body_bytes(&e.json_rpc, &e.mcp), mcp: mcp_options(&e.mcp), + middleware: e.middleware, } }) .collect(), @@ -756,6 +800,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { ..Default::default() }) .collect(), + middleware: rule.middleware, }; (key, proto_rule) }) @@ -776,6 +821,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { run_as_group: p.run_as_group, }), network_policies, + network_middlewares, } } @@ -892,6 +938,7 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { signing_region: e.signing_region.clone(), json_rpc, mcp, + middleware: e.middleware.clone(), } }) .collect(), @@ -903,17 +950,103 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { harness: false, }) .collect(), + middleware: rule.middleware.clone(), }; (key.clone(), yaml_rule) }) .collect(); + let network_middlewares = policy + .network_middlewares + .iter() + .map(|mw| NetworkMiddlewareConfigDef { + name: mw.name.clone(), + middleware: mw.middleware.clone(), + config: mw + .config + .as_ref() + .map(struct_to_json_map) + .unwrap_or_default(), + on_error: mw.on_error.clone(), + endpoints: mw + .endpoints + .as_ref() + .map(|selector| MiddlewareEndpointSelectorDef { + include: selector.include.clone(), + exclude: selector.exclude.clone(), + }), + }) + .collect(); + PolicyFile { version: policy.version, filesystem_policy, landlock, process, network_policies, + network_middlewares, + } +} + +fn json_map_to_struct(map: BTreeMap) -> prost_types::Struct { + prost_types::Struct { + fields: map + .into_iter() + .map(|(key, value)| (key, json_to_protobuf_value(value))) + .collect(), + } +} + +fn json_to_protobuf_value(value: serde_json::Value) -> prost_types::Value { + use prost_types::{ListValue, Struct, Value, value::Kind}; + Value { + kind: Some(match value { + serde_json::Value::Null => Kind::NullValue(0), + serde_json::Value::Bool(value) => Kind::BoolValue(value), + serde_json::Value::Number(value) => { + Kind::NumberValue(value.as_f64().unwrap_or_default()) + } + serde_json::Value::String(value) => Kind::StringValue(value), + serde_json::Value::Array(values) => Kind::ListValue(ListValue { + values: values.into_iter().map(json_to_protobuf_value).collect(), + }), + serde_json::Value::Object(values) => Kind::StructValue(Struct { + fields: values + .into_iter() + .map(|(key, value)| (key, json_to_protobuf_value(value))) + .collect(), + }), + }), + } +} + +fn struct_to_json_map(config: &prost_types::Struct) -> BTreeMap { + config + .fields + .iter() + .map(|(key, value)| (key.clone(), protobuf_value_to_json(value))) + .collect() +} + +fn protobuf_value_to_json(value: &prost_types::Value) -> serde_json::Value { + match value.kind.as_ref() { + Some(prost_types::value::Kind::NullValue(_)) | None => serde_json::Value::Null, + Some(prost_types::value::Kind::BoolValue(value)) => serde_json::Value::Bool(*value), + Some(prost_types::value::Kind::NumberValue(value)) => serde_json::Number::from_f64(*value) + .map_or(serde_json::Value::Null, serde_json::Value::Number), + Some(prost_types::value::Kind::StringValue(value)) => { + serde_json::Value::String(value.clone()) + } + Some(prost_types::value::Kind::ListValue(value)) => { + serde_json::Value::Array(value.values.iter().map(protobuf_value_to_json).collect()) + } + Some(prost_types::value::Kind::StructValue(value)) => serde_json::Value::Object( + value + .fields + .iter() + .map(|(key, value)| (key.clone(), protobuf_value_to_json(value))) + .collect(), + ), } } @@ -1064,6 +1197,7 @@ pub fn restrictive_default_policy() -> SandboxPolicy { run_as_group: "sandbox".into(), }), network_policies: HashMap::new(), + network_middlewares: vec![], } } @@ -1438,6 +1572,87 @@ network_policies: assert_eq!(proto2.network_policies["my_api"].name, "my-custom-api-name"); } + #[test] + fn round_trip_preserves_network_middlewares() { + let yaml = r#" +version: 1 +network_middlewares: + - name: global-redactor + middleware: openshell/secrets + on_error: fail_open + endpoints: + include: ["api.example.com", "*.service.test"] + exclude: ["internal.example.com"] + config: + secrets: ["api_key", "authorization"] + service: + mode: redact + max_matches: 2 + - name: endpoint-redactor + middleware: openshell/secrets +network_policies: + api: + name: api + middleware: ["global-redactor"] + endpoints: + - host: api.example.com + port: 443 + protocol: rest + middleware: ["endpoint-redactor"] + binaries: + - path: /usr/bin/curl +"#; + let proto = parse_sandbox_policy(yaml).expect("parse failed"); + assert_eq!(proto.network_middlewares.len(), 2); + assert_eq!(proto.network_middlewares[0].name, "global-redactor"); + assert_eq!(proto.network_middlewares[0].middleware, "openshell/secrets"); + assert_eq!(proto.network_middlewares[0].on_error, "fail_open"); + assert_eq!( + proto.network_middlewares[0] + .endpoints + .as_ref() + .expect("selector") + .include, + vec!["api.example.com", "*.service.test"] + ); + assert_eq!( + proto.network_middlewares[0] + .endpoints + .as_ref() + .expect("selector") + .exclude, + vec!["internal.example.com"] + ); + assert!( + proto.network_middlewares[0] + .config + .as_ref() + .expect("config") + .fields + .contains_key("service") + ); + assert_eq!( + proto.network_policies["api"].middleware, + vec!["global-redactor"] + ); + assert_eq!( + proto.network_policies["api"].endpoints[0].middleware, + vec!["endpoint-redactor"] + ); + + let yaml_out = serialize_sandbox_policy(&proto).expect("serialize failed"); + let reparsed = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + assert_eq!(reparsed.network_middlewares, proto.network_middlewares); + assert_eq!( + reparsed.network_policies["api"].middleware, + vec!["global-redactor"] + ); + assert_eq!( + reparsed.network_policies["api"].endpoints[0].middleware, + vec!["endpoint-redactor"] + ); + } + #[test] fn restrictive_default_has_no_network_policies() { let policy = restrictive_default_policy(); @@ -1753,6 +1968,7 @@ network_policies: filesystem: None, landlock: None, network_policies: HashMap::new(), + network_middlewares: Vec::new(), }; assert!(validate_sandbox_policy(&policy).is_ok()); } diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index 04f390198..1c63e6ebc 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -989,6 +989,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1007,6 +1008,7 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], + ..Default::default() }; let result = merge_policy( @@ -1035,6 +1037,7 @@ mod tests { name: "existing".to_string(), endpoints: vec![endpoint("api.github.com", 443)], binaries: vec![advisor_binary("/usr/bin/curl")], + ..Default::default() }, ); @@ -1045,6 +1048,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let result = merge_policy( @@ -1076,6 +1080,7 @@ mod tests { ..Default::default() }, ], + ..Default::default() }; let result = merge_policy( @@ -1107,6 +1112,7 @@ mod tests { path: "/usr/bin/python".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1120,6 +1126,7 @@ mod tests { ..Default::default() }], binaries: vec![advisor_binary("/usr/bin/python")], + ..Default::default() }; let result = merge_policy( @@ -1447,6 +1454,7 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1471,6 +1479,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let merged = merge_policy( @@ -1494,6 +1503,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; // Merge an *unrelated* rule for a different host. The proposed rule @@ -1524,6 +1534,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1536,6 +1547,7 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1567,6 +1579,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; // Endpoint exists in the policy but with a *different* binary. The @@ -1582,6 +1595,7 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1618,6 +1632,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1637,6 +1652,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1664,6 +1680,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1686,6 +1703,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1709,6 +1727,7 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], + ..Default::default() }; let merged = merge_policy( @@ -1733,6 +1752,7 @@ mod tests { name: "any_binary_rule".to_string(), endpoints: vec![endpoint("api.github.com", 443)], binaries: vec![], + ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1745,6 +1765,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1802,6 +1823,7 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], + ..Default::default() }; let composed = compose_effective_policy( &SandboxPolicy::default(), @@ -1833,6 +1855,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let result = merge_policy( composed, @@ -1901,6 +1924,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let result = merge_policy( policy, diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index ddfbcaf7d..1eb1b54d2 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -450,6 +450,7 @@ impl ProviderTypeProfile { NetworkPolicyRule { name: rule_name.to_string(), endpoints: self.endpoints.iter().map(endpoint_to_proto).collect(), + middleware: Vec::new(), binaries: self.binaries.iter().map(binary_to_proto).collect(), } } @@ -787,6 +788,7 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { request_body_credential_rewrite: endpoint.request_body_credential_rewrite, advisor_proposed: false, persisted_queries: endpoint.persisted_queries.clone(), + middleware: Vec::new(), graphql_persisted_queries: endpoint .graphql_persisted_queries .iter() diff --git a/crates/openshell-supervisor-middleware/Cargo.toml b/crates/openshell-supervisor-middleware/Cargo.toml new file mode 100644 index 000000000..fdaeb2e82 --- /dev/null +++ b/crates/openshell-supervisor-middleware/Cargo.toml @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-supervisor-middleware" +description = "In-process supervisor middleware contract and built-ins for OpenShell" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +openshell-core = { path = "../openshell-core" } + +miette = { workspace = true } +prost-types = { workspace = true } +regex = { workspace = true } +tonic = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true } + +[lints] +workspace = true diff --git a/crates/openshell-supervisor-middleware/src/builtins/mod.rs b/crates/openshell-supervisor-middleware/src/builtins/mod.rs new file mode 100644 index 000000000..60572d3e8 --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/builtins/mod.rs @@ -0,0 +1,4 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +pub(crate) mod secrets; diff --git a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs new file mode 100644 index 000000000..6c94eb439 --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs @@ -0,0 +1,83 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; + +use miette::{Result, miette}; +use openshell_core::proto::{Decision, Finding, HttpRequestEvaluation, HttpRequestResult}; +use regex::Regex; + +use crate::BUILTIN_SECRETS; + +pub(crate) fn validate_config(config: &prost_types::Struct) -> Result<()> { + let mode = config + .fields + .get("secrets") + .and_then(|value| match value.kind.as_ref() { + Some(prost_types::value::Kind::StringValue(value)) => Some(value.as_str()), + _ => None, + }) + .unwrap_or("redact"); + if mode != "redact" { + return Err(miette!( + "{} only supports config.secrets: redact in phase 1", + BUILTIN_SECRETS + )); + } + Ok(()) +} + +pub(crate) fn evaluate_http_request( + evaluation: &HttpRequestEvaluation, +) -> Result { + let default_config = prost_types::Struct::default(); + validate_config(evaluation.config.as_ref().unwrap_or(&default_config))?; + let text = String::from_utf8(evaluation.body.clone()) + .map_err(|_| miette!("{} requires UTF-8 request bodies", BUILTIN_SECRETS))?; + let (body, count) = redact_common_secrets(&text)?; + let mut result = HttpRequestResult { + decision: Decision::Allow as i32, + reason: String::new(), + body: body.into_bytes(), + has_body: count > 0, + add_headers: HashMap::new(), + findings: Vec::new(), + metadata: HashMap::new(), + }; + if count > 0 { + result.findings.push(Finding { + r#type: "secret.common".into(), + label: "common secret pattern".into(), + count, + confidence: "medium".into(), + severity: "medium".into(), + }); + result + .metadata + .insert("secrets_redacted".into(), count.to_string()); + } + Ok(result) +} + +fn redact_common_secrets(input: &str) -> Result<(String, u32)> { + let patterns = [ + r#"(?i)(api[_-]?key|access[_-]?token|secret|password)(["']?\s*[:=]\s*["'])[^"',\s}]+(["']?)"#, + r#"(sk-[A-Za-z0-9_-]{16,})"#, + ]; + let mut output = input.to_string(); + let mut count = 0u32; + for pattern in patterns { + let regex = Regex::new(pattern).map_err(|e| miette!("{e}"))?; + count = count.saturating_add(regex.find_iter(&output).count() as u32); + output = regex + .replace_all(&output, |captures: ®ex::Captures<'_>| { + if captures.len() >= 4 { + format!("{}{}[REDACTED]{}", &captures[1], &captures[2], &captures[3]) + } else { + "[REDACTED]".to_string() + } + }) + .into_owned(); + } + Ok((output, count)) +} diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs new file mode 100644 index 000000000..7d9161fcf --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -0,0 +1,437 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! In-process supervisor middleware chain execution. + +mod builtins; +mod service; + +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::sync::Arc; + +use miette::{Result, miette}; +pub use service::InProcessMiddlewareService; + +use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; +use openshell_core::proto::{ + Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, NetworkMiddlewareConfig, Process, + RequestContext, +}; +use tonic::Request; + +pub const API_VERSION: &str = "openshell.middleware.v1"; +pub const HTTP_REQUEST_OPERATION: &str = "HttpRequest"; +pub const PRE_CREDENTIALS_PHASE: &str = "pre_credentials"; +pub const BUILTIN_SECRETS: &str = "openshell/secrets"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OnError { + FailClosed, + FailOpen, +} + +impl OnError { + pub fn parse(value: &str) -> Result { + match value { + "" | "fail_closed" => Ok(Self::FailClosed), + "fail_open" => Ok(Self::FailOpen), + other => Err(miette!( + "invalid middleware on_error '{other}', expected fail_closed or fail_open" + )), + } + } +} + +#[derive(Debug, Clone)] +pub struct ChainEntry { + pub name: String, + pub implementation: String, + pub config: prost_types::Struct, + pub on_error: OnError, +} + +impl TryFrom<&NetworkMiddlewareConfig> for ChainEntry { + type Error = miette::Report; + + fn try_from(value: &NetworkMiddlewareConfig) -> Result { + if value.name.is_empty() { + return Err(miette!("middleware config name cannot be empty")); + } + if value.middleware.is_empty() { + return Err(miette!( + "middleware config '{}' must name an implementation", + value.name + )); + } + Ok(Self { + name: value.name.clone(), + implementation: value.middleware.clone(), + config: value.config.clone().unwrap_or_default(), + on_error: OnError::parse(&value.on_error)?, + }) + } +} + +#[derive(Debug, Clone)] +pub struct HttpRequestInput { + pub request_id: String, + pub sandbox_id: String, + pub binary: String, + pub pid: u32, + pub ancestors: Vec, + pub scheme: String, + pub host: String, + pub port: u16, + pub method: String, + pub path: String, + pub query: String, + pub headers: BTreeMap, + pub body: Vec, +} + +#[derive(Debug, Clone)] +pub struct ChainOutcome { + pub allowed: bool, + pub reason: String, + pub body: Vec, + pub added_headers: BTreeMap, + pub findings: Vec, + pub metadata: BTreeMap>, + pub applied: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct NamespacedFinding { + pub middleware: String, + pub finding: Finding, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MiddlewareInvocation { + pub name: String, + pub implementation: String, + pub decision: Decision, + pub transformed: bool, +} + +#[derive(Clone)] +pub struct ChainRunner { + service: Arc, +} + +impl Default for ChainRunner { + fn default() -> Self { + Self::new(Arc::new(InProcessMiddlewareService)) + } +} + +impl ChainRunner { + pub fn new(service: Arc) -> Self { + Self { service } + } + + pub async fn evaluate( + &self, + entries: &[ChainEntry], + input: HttpRequestInput, + ) -> Result { + let mut headers = input.headers.clone(); + let mut body = input.body.clone(); + let mut added_headers = BTreeMap::new(); + let mut findings = Vec::new(); + let mut metadata = BTreeMap::new(); + let mut applied = Vec::new(); + + for entry in entries { + let evaluation = build_evaluation(entry, &input, &headers, &body); + let result = match self + .service + .evaluate_http_request(Request::new(evaluation)) + .await + { + Ok(result) => result.into_inner(), + Err(err) => match entry.on_error { + OnError::FailOpen => { + applied.push(MiddlewareInvocation { + name: entry.name.clone(), + implementation: entry.implementation.clone(), + decision: Decision::Allow, + transformed: false, + }); + continue; + } + OnError::FailClosed => { + return Ok(ChainOutcome { + allowed: false, + reason: format!("middleware_failed: {}", safe_reason(&err.to_string())), + body, + added_headers, + findings, + metadata, + applied, + }); + } + }, + }; + + validate_header_mutations(&headers, &result.add_headers)?; + for (name, value) in &result.add_headers { + headers.insert(name.to_ascii_lowercase(), value.clone()); + added_headers.insert(name.to_ascii_lowercase(), value.clone()); + } + let transformed = result.has_body; + if result.has_body { + body = result.body.clone(); + } + for finding in result.findings { + findings.push(NamespacedFinding { + middleware: entry.name.clone(), + finding, + }); + } + if !result.metadata.is_empty() { + metadata.insert( + entry.name.clone(), + result.metadata.clone().into_iter().collect(), + ); + } + applied.push(MiddlewareInvocation { + name: entry.name.clone(), + implementation: entry.implementation.clone(), + decision: Decision::try_from(result.decision).unwrap_or(Decision::Unspecified), + transformed, + }); + if result.decision == Decision::Deny as i32 { + return Ok(ChainOutcome { + allowed: false, + reason: safe_reason(&result.reason), + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + + Ok(ChainOutcome { + allowed: true, + reason: String::new(), + body, + added_headers, + findings, + metadata, + applied, + }) + } +} + +fn build_evaluation( + entry: &ChainEntry, + input: &HttpRequestInput, + headers: &BTreeMap, + body: &[u8], +) -> HttpRequestEvaluation { + HttpRequestEvaluation { + api_version: API_VERSION.into(), + binding_id: entry.implementation.clone(), + phase: PRE_CREDENTIALS_PHASE.into(), + context: Some(RequestContext { + request_id: input.request_id.clone(), + sandbox_id: input.sandbox_id.clone(), + originating_process: Some(Process { + binary: input.binary.clone(), + pid: input.pid, + ancestors: input.ancestors.clone(), + }), + }), + config: Some(entry.config.clone()), + target: Some(HttpRequestTarget { + scheme: input.scheme.clone(), + host: input.host.clone(), + port: u32::from(input.port), + method: input.method.clone(), + path: input.path.clone(), + query: input.query.clone(), + }), + headers: headers.clone().into_iter().collect(), + body: body.to_vec(), + } +} + +fn validate_header_mutations( + existing_headers: &BTreeMap, + mutations: &HashMap, +) -> Result<()> { + let mut seen = HashSet::new(); + for name in mutations.keys() { + let lower = name.to_ascii_lowercase(); + if !seen.insert(lower.clone()) || existing_headers.contains_key(&lower) { + return Err(miette!( + "middleware cannot rewrite existing header '{name}'" + )); + } + if !is_safe_append_header(&lower) { + return Err(miette!("middleware cannot append unsafe header '{name}'")); + } + } + Ok(()) +} + +fn is_safe_append_header(name: &str) -> bool { + if name.is_empty() + || name.contains(':') + || name.bytes().any(|b| b <= 0x20 || b >= 0x7f) + || matches!( + name, + "authorization" | "cookie" | "host" | "content-length" | "transfer-encoding" + ) + || name.starts_with("x-amz-") + || name.starts_with("x-openshell-credential") + { + return false; + } + name.starts_with("x-openshell-middleware-") +} + +pub(crate) fn safe_reason(reason: &str) -> String { + reason + .chars() + .filter(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-' | ':' | ' ')) + .take(160) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; + + fn entry(name: &str, on_error: OnError) -> ChainEntry { + ChainEntry { + name: name.into(), + implementation: BUILTIN_SECRETS.into(), + config: prost_types::Struct { + fields: [( + "secrets".into(), + prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue("redact".into())), + }, + )] + .into_iter() + .collect(), + }, + on_error, + } + } + + fn input(body: &str) -> HttpRequestInput { + HttpRequestInput { + request_id: "req".into(), + sandbox_id: "sbx".into(), + binary: "/usr/bin/curl".into(), + pid: 42, + ancestors: vec![], + scheme: "https".into(), + host: "api.example.com".into(), + port: 443, + method: "POST".into(), + path: "/v1".into(), + query: String::new(), + headers: BTreeMap::new(), + body: body.as_bytes().to_vec(), + } + } + + #[tokio::test] + async fn redacts_common_secret_patterns() { + let outcome = ChainRunner::default() + .evaluate( + &[entry("redact", OnError::FailClosed)], + input(r#"{"api_key":"sk-1234567890abcdef"}"#), + ) + .await + .expect("evaluate"); + assert!(outcome.allowed); + assert_eq!( + String::from_utf8(outcome.body).expect("utf8"), + r#"{"api_key":"[REDACTED]"}"# + ); + assert_eq!(outcome.findings[0].finding.count, 1); + } + + #[tokio::test] + async fn transformed_body_feeds_next_stage() { + let entries = [ + entry("first", OnError::FailClosed), + entry("second", OnError::FailClosed), + ]; + let outcome = ChainRunner::default() + .evaluate(&entries, input(r#"password="top-secret""#)) + .await + .expect("evaluate"); + assert!(outcome.allowed); + assert_eq!( + String::from_utf8(outcome.body).expect("utf8"), + r#"password="[REDACTED]""# + ); + assert_eq!(outcome.applied.len(), 2); + } + + #[tokio::test] + async fn fail_open_allows_unavailable_middleware() { + let unavailable = ChainEntry { + name: "missing".into(), + implementation: "third-party/missing".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailOpen, + }; + let outcome = ChainRunner::default() + .evaluate(&[unavailable], input("hello")) + .await + .expect("evaluate"); + assert!(outcome.allowed); + assert_eq!(outcome.body, b"hello"); + } + + #[tokio::test] + async fn fail_closed_denies_unavailable_middleware() { + let unavailable = ChainEntry { + name: "missing".into(), + implementation: "third-party/missing".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }; + let outcome = ChainRunner::default() + .evaluate(&[unavailable], input("hello")) + .await + .expect("evaluate"); + assert!(!outcome.allowed); + assert!(outcome.reason.starts_with("middleware_failed:")); + } + + #[tokio::test] + async fn in_process_service_describes_builtin_binding() { + let manifest = InProcessMiddlewareService + .describe(Request::new(())) + .await + .expect("describe") + .into_inner(); + assert_eq!(manifest.api_version, API_VERSION); + assert_eq!(manifest.bindings[0].id, BUILTIN_SECRETS); + assert_eq!(manifest.bindings[0].operation, HTTP_REQUEST_OPERATION); + assert_eq!(manifest.bindings[0].phase, PRE_CREDENTIALS_PHASE); + } + + #[test] + fn unsafe_header_mutation_is_rejected() { + let err = validate_header_mutations( + &BTreeMap::new(), + &[("Authorization".into(), "Bearer nope".into())] + .into_iter() + .collect(), + ) + .expect_err("unsafe header"); + assert!(err.to_string().contains("unsafe header")); + } +} diff --git a/crates/openshell-supervisor-middleware/src/service.rs b/crates/openshell-supervisor-middleware/src/service.rs new file mode 100644 index 000000000..31cca5694 --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/service.rs @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; +use openshell_core::proto::{ + HttpRequestEvaluation, HttpRequestResult, MiddlewareBinding, MiddlewareManifest, + ValidateConfigRequest, ValidateConfigResponse, +}; +use tonic::{Request, Response, Status}; + +use crate::{ + API_VERSION, BUILTIN_SECRETS, HTTP_REQUEST_OPERATION, PRE_CREDENTIALS_PHASE, builtins, + safe_reason, +}; + +#[derive(Debug, Default)] +pub struct InProcessMiddlewareService; + +#[tonic::async_trait] +impl SupervisorMiddleware for InProcessMiddlewareService { + async fn describe( + &self, + _request: Request<()>, + ) -> Result, Status> { + Ok(Response::new(MiddlewareManifest { + api_version: API_VERSION.into(), + name: "openshell/in-process".into(), + service_version: env!("CARGO_PKG_VERSION").into(), + bindings: vec![MiddlewareBinding { + id: BUILTIN_SECRETS.into(), + operation: HTTP_REQUEST_OPERATION.into(), + phase: PRE_CREDENTIALS_PHASE.into(), + }], + })) + } + + async fn validate_config( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let config = request.config.unwrap_or_default(); + let validation = match request.binding_id.as_str() { + BUILTIN_SECRETS => builtins::secrets::validate_config(&config), + other => Err(miette::miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + }; + Ok(Response::new(match validation { + Ok(()) => ValidateConfigResponse { + valid: true, + reason: String::new(), + }, + Err(err) => ValidateConfigResponse { + valid: false, + reason: safe_reason(&err.to_string()), + }, + })) + } + + async fn evaluate_http_request( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let result = match request.binding_id.as_str() { + BUILTIN_SECRETS => builtins::secrets::evaluate_http_request(&request), + other => Err(miette::miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + } + .map_err(|err| Status::invalid_argument(safe_reason(&err.to_string())))?; + Ok(Response::new(result)) + } +} diff --git a/crates/openshell-supervisor-network/Cargo.toml b/crates/openshell-supervisor-network/Cargo.toml index 7d0079f7b..fd8fad5f7 100644 --- a/crates/openshell-supervisor-network/Cargo.toml +++ b/crates/openshell-supervisor-network/Cargo.toml @@ -15,6 +15,7 @@ openshell-core = { path = "../openshell-core" } openshell-ocsf = { path = "../openshell-ocsf" } openshell-policy = { path = "../openshell-policy" } openshell-router = { path = "../openshell-router" } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } apollo-parser = { workspace = true } aws-sigv4 = { version = "1", features = ["sign-http", "http1"] } @@ -28,6 +29,7 @@ glob = { workspace = true } hex = "0.4" ipnet = "2" miette = { workspace = true } +prost-types = { workspace = true } rcgen = { workspace = true } regorus = { version = "0.9", default-features = false, features = ["std", "arc", "glob"] } reqwest = { workspace = true } diff --git a/crates/openshell-supervisor-network/data/sandbox-policy.rego b/crates/openshell-supervisor-network/data/sandbox-policy.rego index efcdf0732..afa4f6947 100644 --- a/crates/openshell-supervisor-network/data/sandbox-policy.rego +++ b/crates/openshell-supervisor-network/data/sandbox-policy.rego @@ -856,6 +856,22 @@ matched_endpoint_config := _matching_endpoint_configs[0] if { count(_matching_endpoint_configs) > 0 } +network_middlewares := object.get(data, "network_middlewares", []) + +_matching_middleware_contexts := [ctx | + some pname + _matching_policy_names[pname] + policy := data.network_policies[pname] + some ep + ep := policy.endpoints[_] + endpoint_matches_request(ep, input.network) + ctx := { + "policy": pname, + "policy_middleware": object.get(policy, "middleware", []), + "endpoint": ep, + } +] + _policy_has_exact_declared_endpoint(policy) if { some ep ep := policy.endpoints[_] @@ -909,7 +925,7 @@ endpoint_path_matches_request(ep, request) if { } # An endpoint has extended config if it specifies L7 protocol, allowed_ips, -# or an explicit tls mode (e.g. tls: skip). +# middleware, or an explicit tls mode (e.g. tls: skip). endpoint_has_extended_config(ep) if { ep.protocol } @@ -918,6 +934,10 @@ endpoint_has_extended_config(ep) if { count(object.get(ep, "allowed_ips", [])) > 0 } +endpoint_has_extended_config(ep) if { + count(object.get(ep, "middleware", [])) > 0 +} + endpoint_has_extended_config(ep) if { ep.tls } diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index ed2bde113..4d501d0a3 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -15,9 +15,12 @@ use miette::{IntoDiagnostic, Result, miette}; use openshell_core::activity::{ActivitySender, try_record_activity}; use openshell_core::secrets::{self, SecretResolver}; use openshell_ocsf::{ - ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, - NetworkActivityBuilder, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, + ActionId, ActivityId, DetectionFindingBuilder, DispositionId, Endpoint, FindingInfo, + HttpActivityBuilder, HttpRequest, NetworkActivityBuilder, SeverityId, StatusId, Url as OcsfUrl, + ocsf_emit, }; +use std::collections::BTreeMap; +use std::path::PathBuf; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn}; @@ -450,6 +453,37 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let chain = + engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let req = + match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, @@ -734,6 +768,167 @@ fn jsonrpc_engine_type(protocol: L7Protocol) -> &'static str { } } +enum MiddlewareApplyResult { + Allowed(crate::l7::provider::L7Request), + Denied(String), +} + +async fn apply_middleware_chain( + req: crate::l7::provider::L7Request, + client: &mut C, + ctx: &L7EvalContext, + chain: Vec, + generation_guard: &PolicyGenerationGuard, +) -> Result { + if chain.is_empty() { + return Ok(MiddlewareApplyResult::Allowed(req)); + } + let buffered = + crate::l7::rest::buffer_request_body_for_middleware(&req, client, Some(generation_guard)) + .await?; + let headers = safe_middleware_headers(&buffered.headers)?; + let input = openshell_supervisor_middleware::HttpRequestInput { + request_id: uuid::Uuid::new_v4().to_string(), + sandbox_id: String::new(), + binary: ctx.binary_path.clone(), + pid: 0, + ancestors: ctx.ancestors.clone(), + scheme: "https".into(), + host: ctx.host.clone(), + port: ctx.port, + method: req.action.clone(), + path: req.target.clone(), + query: String::new(), + headers, + body: buffered.body, + }; + let outcome = openshell_supervisor_middleware::ChainRunner::default() + .evaluate(&chain, input) + .await?; + emit_middleware_events(ctx, &req, &outcome); + let rebuilt = crate::l7::rest::rebuild_request_with_buffered_body( + &req, + &buffered.headers, + &outcome.body, + &outcome.added_headers, + )?; + if outcome.allowed { + Ok(MiddlewareApplyResult::Allowed(rebuilt)) + } else { + Ok(MiddlewareApplyResult::Denied(outcome.reason)) + } +} + +fn safe_middleware_headers(headers: &[u8]) -> Result> { + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut out = BTreeMap::new(); + for line in header_str.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + if name.is_empty() + || matches!( + name.as_str(), + "authorization" | "cookie" | "host" | "content-length" | "transfer-encoding" + ) + || name.starts_with("x-amz-") + || name.starts_with("x-openshell-credential") + { + continue; + } + out.insert(name, value.trim().to_string()); + } + Ok(out) +} + +fn middleware_network_input(ctx: &L7EvalContext) -> crate::opa::NetworkInput { + crate::opa::NetworkInput { + host: ctx.host.clone(), + port: ctx.port, + binary_path: PathBuf::from(&ctx.binary_path), + binary_sha256: String::new(), + ancestors: ctx.ancestors.iter().map(PathBuf::from).collect(), + cmdline_paths: ctx.cmdline_paths.iter().map(PathBuf::from).collect(), + } +} + +fn emit_middleware_events( + ctx: &L7EvalContext, + req: &crate::l7::provider::L7Request, + outcome: &openshell_supervisor_middleware::ChainOutcome, +) { + for invocation in &outcome.applied { + let allowed = invocation.decision == openshell_core::proto::Decision::Allow; + let event = HttpActivityBuilder::new(openshell_ocsf::ctx::ctx()) + .activity(ActivityId::Other) + .action(if allowed { + ActionId::Allowed + } else { + ActionId::Denied + }) + .disposition(if allowed { + DispositionId::Allowed + } else { + DispositionId::Blocked + }) + .severity(if allowed { + SeverityId::Informational + } else { + SeverityId::Medium + }) + .http_request(HttpRequest::new( + &req.action, + OcsfUrl::new("http", &ctx.host, &req.target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(&ctx.policy_name, "middleware") + .message(format!( + "MIDDLEWARE {} {} decision={:?} transformed={}", + invocation.name, + invocation.implementation, + invocation.decision, + invocation.transformed + )) + .build(); + ocsf_emit!(event); + } + if !outcome.allowed && outcome.reason.starts_with("middleware_failed:") { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(SeverityId::High) + .finding_info(FindingInfo::new( + "openshell.middleware.failure", + "Supervisor middleware failure", + )) + .message("Required supervisor middleware failed closed") + .build(); + ocsf_emit!(event); + } + for finding in &outcome.findings { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(match finding.finding.severity.as_str() { + "high" => SeverityId::High, + "low" => SeverityId::Low, + _ => SeverityId::Medium, + }) + .finding_info(FindingInfo::new( + &finding.finding.r#type, + &finding.finding.label, + )) + .evidence_pairs(&[ + ("middleware", &finding.middleware), + ("count", &finding.finding.count.to_string()), + ]) + .message(format!( + "Middleware finding {} count={}", + finding.finding.r#type, finding.finding.count + )) + .build(); + ocsf_emit!(event); + } +} + /// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. async fn relay_rest( config: &L7EndpointConfig, @@ -903,6 +1098,37 @@ where let _ = &eval_target; if allowed || config.enforcement == EnforcementMode::Audit { + let chain = + engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let req = + match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + provider + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let req_with_auth = match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await { Ok(req) => req, @@ -1336,6 +1562,37 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let chain = + engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let req = + match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( &req, client, @@ -1674,6 +1931,7 @@ pub async fn relay_passthrough_with_credentials( upstream: &mut U, ctx: &L7EvalContext, generation_guard: &PolicyGenerationGuard, + middleware_engine: Option<&crate::opa::OpaEngine>, ) -> Result<()> where C: AsyncRead + AsyncWrite + Unpin + Send, @@ -1756,6 +2014,43 @@ where ocsf_emit!(event); } + let req = if let Some(engine) = middleware_engine { + let input = middleware_network_input(ctx); + let (chain, generation) = + engine.query_middleware_chain_with_generation(&input, &req.target)?; + if generation != generation_guard.captured_generation() { + return Ok(()); + } + match apply_middleware_chain(req, client, ctx, chain, generation_guard).await? { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: "HTTP".into(), + target: redacted_target.clone(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + } + } else { + req + }; + let req_with_auth = match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await { Ok(req) => req, @@ -1901,6 +2196,63 @@ network_policies: (config, tunnel_engine, ctx, fixture) } + fn middleware_relay_context( + middleware_impl: &str, + on_error: &str, + ) -> (L7EndpointConfig, TunnelPolicyEngine, L7EvalContext) { + let data = format!( + r#" +network_middlewares: + - name: request-middleware + middleware: {middleware_impl} + on_error: {on_error} +network_policies: + rest_api: + name: rest_api + middleware: ["request-middleware"] + endpoints: + - host: api.example.test + port: 8080 + protocol: rest + enforcement: enforce + rules: + - allow: + method: POST + path: "/v1/**" + binaries: + - {{ path: /usr/bin/curl }} +"# + ); + let engine = OpaEngine::from_strings(TEST_POLICY, &data).unwrap(); + let input = NetworkInput { + host: "api.example.test".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 8080, + policy_name: "rest_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + (config, tunnel_engine, ctx) + } + fn passthrough_token_grant_relay_context( resolver_response: std::result::Result<&str, &str>, ) -> ( @@ -2112,7 +2464,10 @@ network_policies: .unwrap(); let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); - assert!(upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n")); + assert!( + upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n"), + "unexpected upstream request: {upstream_request:?}" + ); assert!(upstream_request.contains("Authorization: Bearer grant-token\r\n")); assert!(!upstream_request.contains("stale-token")); assert_eq!(authorization_header_count(&upstream_request), 1); @@ -2194,6 +2549,115 @@ network_policies: fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); } + #[tokio::test] + async fn l7_rest_middleware_redacts_body_before_upstream() { + let (config, tunnel_engine, ctx) = + middleware_relay_context("openshell/secrets", "fail_closed"); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + let body = br#"{"api_key":"sk-1234567890abcdef"}"#; + let request = format!( + "POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + std::str::from_utf8(body).unwrap() + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut upstream_request = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("request should reach upstream") + .unwrap(); + let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); + assert!(upstream_request.contains(r#""api_key":"[REDACTED]""#)); + assert!(!upstream_request.contains("sk-1234567890abcdef")); + + upstream + .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn l7_rest_middleware_fail_closed_does_not_reach_upstream() { + let (config, tunnel_engine, ctx) = + middleware_relay_context("example/unavailable", "fail_closed"); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: 2\r\nConnection: close\r\n\r\n{}", + ) + .await + .unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("denial should reach client") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("middleware_failed")); + + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_request), + ) + .await; + assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive request bytes" + ); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + #[tokio::test] async fn passthrough_relay_injects_token_grant_authorization_header() { let (generation_guard, ctx, fixture) = @@ -2206,6 +2670,7 @@ network_policies: &mut relay_upstream, &ctx, &generation_guard, + None, ) .await }); @@ -2268,6 +2733,7 @@ network_policies: &mut relay_upstream, &ctx, &generation_guard, + None, ) .await }); @@ -3173,7 +3639,10 @@ network_policies: .expect("first request should reach upstream") .unwrap(); let first_upstream = String::from_utf8_lossy(&first_upstream[..n]); - assert!(first_upstream.starts_with("POST /write HTTP/1.1")); + assert!( + first_upstream.starts_with("POST /write HTTP/1.1"), + "unexpected upstream request: {first_upstream:?}" + ); upstream .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nOK") @@ -3243,6 +3712,7 @@ network_policies: &mut relay_upstream, &ctx, &generation_guard, + None, ) .await }); diff --git a/crates/openshell-supervisor-network/src/l7/rest.rs b/crates/openshell-supervisor-network/src/l7/rest.rs index 0558a67e5..1a4036abd 100644 --- a/crates/openshell-supervisor-network/src/l7/rest.rs +++ b/crates/openshell-supervisor-network/src/l7/rest.rs @@ -27,6 +27,7 @@ const MAX_REWRITE_BODY_BYTES: usize = 256 * 1024; /// Maximum body bytes for `SigV4` body-signing mode. Larger than the credential /// rewrite limit because Bedrock payloads can be several megabytes. const MAX_SIGV4_BODY_BYTES: usize = 10 * 1024 * 1024; +pub(crate) const MAX_MIDDLEWARE_BODY_BYTES: usize = MAX_REWRITE_BODY_BYTES; const RELAY_BUF_SIZE: usize = 8192; const HTTP_METHOD_PREFIXES: &[&[u8]] = &[ b"GET ", @@ -768,6 +769,83 @@ struct PreparedRequestBody { body: Vec, } +pub(crate) struct BufferedRequestBody { + pub(crate) headers: Vec, + pub(crate) body: Vec, +} + +pub(crate) async fn buffer_request_body_for_middleware( + req: &L7Request, + client: &mut C, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result { + let header_end = req + .raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(req.raw_header.len(), |p| p + 4); + let headers = req.raw_header[..header_end].to_vec(); + let already_read = &req.raw_header[header_end..]; + match req.body_length { + BodyLength::None => Ok(BufferedRequestBody { + headers, + body: already_read.to_vec(), + }), + BodyLength::ContentLength(len) => { + let len = usize::try_from(len) + .map_err(|_| miette!("request body is too large for middleware"))?; + if len > MAX_MIDDLEWARE_BODY_BYTES { + return Err(miette!( + "middleware buffers at most {MAX_MIDDLEWARE_BODY_BYTES} request body bytes" + )); + } + let initial_len = already_read.len().min(len); + let mut body = Vec::with_capacity(len); + body.extend_from_slice(&already_read[..initial_len]); + let mut remaining = len.saturating_sub(initial_len); + let mut buf = [0u8; RELAY_BUF_SIZE]; + while remaining > 0 { + let to_read = remaining.min(buf.len()); + let n = client.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!( + "Connection closed with {remaining} body bytes remaining" + )); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + body.extend_from_slice(&buf[..n]); + remaining -= n; + } + Ok(BufferedRequestBody { headers, body }) + } + BodyLength::Chunked => { + let body = collect_chunked_body(client, already_read, generation_guard).await?; + Ok(BufferedRequestBody { headers, body }) + } + } +} + +pub(crate) fn rebuild_request_with_buffered_body( + req: &L7Request, + headers: &[u8], + body: &[u8], + add_headers: &std::collections::BTreeMap, +) -> Result { + let mut header_bytes = set_content_length(headers, body.len())?; + header_bytes = strip_header(&header_bytes, "transfer-encoding")?; + header_bytes = append_headers(&header_bytes, add_headers)?; + header_bytes.extend_from_slice(body); + Ok(L7Request { + action: req.action.clone(), + target: req.target.clone(), + query_params: req.query_params.clone(), + raw_header: header_bytes, + body_length: BodyLength::ContentLength(body.len() as u64), + }) +} + async fn collect_and_rewrite_request_body( req: &L7Request, client: &mut C, @@ -1160,6 +1238,50 @@ fn set_content_length(headers: &[u8], len: usize) -> Result> { Ok(out.into_bytes()) } +fn strip_header(headers: &[u8], strip_name: &str) -> Result> { + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut out = String::with_capacity(header_str.len()); + for line in header_str.split("\r\n") { + if line.is_empty() { + out.push_str("\r\n"); + break; + } + if line + .split_once(':') + .is_some_and(|(name, _)| name.trim().eq_ignore_ascii_case(strip_name)) + { + continue; + } + out.push_str(line); + out.push_str("\r\n"); + } + Ok(out.into_bytes()) +} + +fn append_headers( + headers: &[u8], + add_headers: &std::collections::BTreeMap, +) -> Result> { + if add_headers.is_empty() { + return Ok(headers.to_vec()); + } + let split = headers + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(headers.len(), |pos| pos); + let mut out = Vec::with_capacity(headers.len() + add_headers.len() * 32); + out.extend_from_slice(&headers[..split]); + for (name, value) in add_headers { + out.extend_from_slice(b"\r\n"); + out.extend_from_slice(name.as_bytes()); + out.extend_from_slice(b": "); + out.extend_from_slice(value.as_bytes()); + } + out.extend_from_slice(b"\r\n\r\n"); + Ok(out) +} + pub(crate) fn request_is_websocket_upgrade(raw_header: &[u8]) -> bool { let header_end = raw_header .windows(4) diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index fbab5fedd..451c57e59 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -13,6 +13,8 @@ use openshell_core::policy::{ }; use openshell_core::proto::SandboxPolicy as ProtoSandboxPolicy; use openshell_policy::L7ConfigStanza; +use openshell_supervisor_middleware::ChainEntry; +use std::collections::HashSet; use std::path::{Path, PathBuf}; use std::sync::{ Arc, Mutex, @@ -132,6 +134,19 @@ impl TunnelPolicyEngine { pub(crate) fn engine(&self) -> &Mutex { &self.engine } + + /// Query the ordered middleware chain for a request path within this tunnel. + pub fn query_middleware_chain( + &self, + input: &NetworkInput, + request_path: &str, + ) -> Result> { + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + query_middleware_chain_locked(&mut engine, input, request_path) + } } impl OpaEngine { @@ -200,6 +215,14 @@ impl OpaEngine { .map_err(|e| miette::miette!("internal: failed to parse proto JSON: {e}"))?; // Validate BEFORE expanding presets + let middleware_errors = validate_middleware_policies(&data); + if !middleware_errors.is_empty() { + return Err(miette::miette!( + "middleware policy validation failed:\n{}", + middleware_errors.join("\n") + )); + } + let (errors, warnings) = crate::l7::validate_l7_policies(&data); for w in &warnings { openshell_ocsf::ocsf_emit!( @@ -548,6 +571,21 @@ impl OpaEngine { } } + /// Query the ordered middleware chain for a parsed HTTP request path. + pub fn query_middleware_chain_with_generation( + &self, + input: &NetworkInput, + request_path: &str, + ) -> Result<(Vec, u64)> { + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + let generation = self.current_generation(); + let chain = query_middleware_chain_locked(&mut engine, input, request_path)?; + Ok((chain, generation)) + } + /// Query `allowed_ips` from the matched endpoint config for a given request. /// /// Returns the list of CIDR/IP strings from the endpoint's `allowed_ips` @@ -687,6 +725,243 @@ fn get_str_array(val: ®orus::Value, key: &str) -> Vec { } } +fn network_input_json(input: &NetworkInput) -> serde_json::Value { + let ancestor_strs: Vec = input + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + let cmdline_strs: Vec = input + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + serde_json::json!({ + "exec": { + "path": input.binary_path.to_string_lossy(), + "ancestors": ancestor_strs, + "cmdline_paths": cmdline_strs, + }, + "network": { + "host": input.host, + "port": input.port, + } + }) +} + +#[derive(Debug, Clone)] +struct MiddlewareContext { + policy_middleware: Vec, + endpoint_middleware: Vec, + endpoint_path: String, +} + +fn query_middleware_chain_locked( + engine: &mut regorus::Engine, + input: &NetworkInput, + request_path: &str, +) -> Result> { + engine + .set_input_json(&network_input_json(input).to_string()) + .map_err(|e| miette::miette!("{e}"))?; + + let configs_val = engine + .eval_rule("data.openshell.sandbox.network_middlewares".into()) + .map_err(|e| miette::miette!("{e}"))?; + let configs = parse_middleware_configs(&configs_val)?; + if configs.is_empty() { + return Ok(Vec::new()); + } + let contexts_val = engine + .eval_rule("data.openshell.sandbox._matching_middleware_contexts".into()) + .map_err(|e| miette::miette!("{e}"))?; + let contexts = parse_middleware_contexts(&contexts_val); + let Some(context) = select_middleware_context(&contexts, request_path) else { + return Ok(global_middleware_entries( + &configs, + &input.host, + &HashSet::new(), + )?); + }; + + let mut explicit = Vec::new(); + for name in context + .policy_middleware + .iter() + .chain(context.endpoint_middleware.iter()) + { + if !explicit.contains(name) { + explicit.push(name.clone()); + } + } + let explicit_set: HashSet = explicit.iter().cloned().collect(); + let mut ordered = global_middleware_entries(&configs, &input.host, &explicit_set)?; + for name in explicit { + if !ordered.iter().any(|entry| entry.name == name) { + let config = configs + .iter() + .find(|config| get_str(config, "name").as_deref() == Some(name.as_str())) + .ok_or_else(|| miette::miette!("unknown middleware config '{name}'"))?; + ordered.push(chain_entry_from_value(config)?); + } + } + Ok(ordered) +} + +fn parse_middleware_configs(value: ®orus::Value) -> Result> { + match value { + regorus::Value::Undefined => Ok(Vec::new()), + regorus::Value::Array(values) => Ok(values.to_vec()), + other => Err(miette::miette!( + "network_middlewares must be an array, got {other:?}" + )), + } +} + +fn parse_middleware_contexts(value: ®orus::Value) -> Vec { + let regorus::Value::Array(values) = value else { + return Vec::new(); + }; + values + .iter() + .filter_map(|value| { + let regorus::Value::Object(_) = value else { + return None; + }; + let endpoint = get_field(value, "endpoint")?; + Some(MiddlewareContext { + policy_middleware: get_str_array(value, "policy_middleware"), + endpoint_middleware: get_str_array(endpoint, "middleware"), + endpoint_path: get_str(endpoint, "path").unwrap_or_default(), + }) + }) + .collect() +} + +fn select_middleware_context<'a>( + contexts: &'a [MiddlewareContext], + request_path: &str, +) -> Option<&'a MiddlewareContext> { + contexts + .iter() + .filter(|context| crate::l7::endpoint_path_matches(&context.endpoint_path, request_path)) + .max_by_key(|context| { + if context.endpoint_path.is_empty() { + 0 + } else { + context.endpoint_path.chars().filter(|c| *c != '*').count() + } + }) +} + +fn global_middleware_entries( + configs: &[regorus::Value], + host: &str, + explicit: &HashSet, +) -> Result> { + let mut entries = Vec::new(); + for config in configs { + let name = get_str(config, "name").unwrap_or_default(); + if explicit.contains(&name) { + continue; + } + if middleware_selector_matches(config, host) { + entries.push(chain_entry_from_value(config)?); + } + } + Ok(entries) +} + +fn middleware_selector_matches(config: ®orus::Value, host: &str) -> bool { + let Some(selector) = get_field(config, "endpoints") else { + return false; + }; + let includes = get_str_array(selector, "include"); + let excludes = get_str_array(selector, "exclude"); + let included = + !includes.is_empty() && includes.iter().any(|pattern| host_matches(pattern, host)); + let excluded = excludes.iter().any(|pattern| host_matches(pattern, host)); + included && !excluded +} + +fn host_matches(pattern: &str, host: &str) -> bool { + if pattern == "*" || pattern == "**" { + return true; + } + if !pattern.contains('*') { + return pattern.eq_ignore_ascii_case(host); + } + glob::Pattern::new(&pattern.to_ascii_lowercase()) + .is_ok_and(|pattern| pattern.matches(&host.to_ascii_lowercase())) +} + +fn chain_entry_from_value(value: ®orus::Value) -> Result { + let name = get_str(value, "name").unwrap_or_default(); + let implementation = get_str(value, "middleware").unwrap_or_default(); + Ok(ChainEntry { + name, + implementation, + config: get_field(value, "config") + .map(regorus_value_to_struct) + .unwrap_or_default(), + on_error: openshell_supervisor_middleware::OnError::parse( + get_str(value, "on_error").as_deref().unwrap_or_default(), + )?, + }) +} + +fn get_field<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Value> { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => map.get(&key_val), + _ => None, + } +} + +fn regorus_value_to_struct(value: ®orus::Value) -> prost_types::Struct { + let regorus::Value::Object(map) = value else { + return prost_types::Struct::default(); + }; + prost_types::Struct { + fields: map + .iter() + .filter_map(|(key, value)| match key { + regorus::Value::String(key) => { + Some((key.to_string(), regorus_value_to_prost(value))) + } + _ => None, + }) + .collect(), + } +} + +fn regorus_value_to_prost(value: ®orus::Value) -> prost_types::Value { + use prost_types::{ListValue, Struct, Value, value::Kind}; + Value { + kind: Some(match value { + regorus::Value::Bool(value) => Kind::BoolValue(*value), + regorus::Value::Number(value) => Kind::NumberValue(value.as_f64().unwrap_or_default()), + regorus::Value::String(value) => Kind::StringValue(value.to_string()), + regorus::Value::Array(values) => Kind::ListValue(ListValue { + values: values.iter().map(regorus_value_to_prost).collect(), + }), + regorus::Value::Object(values) => Kind::StructValue(Struct { + fields: values + .iter() + .filter_map(|(key, value)| match key { + regorus::Value::String(key) => { + Some((key.to_string(), regorus_value_to_prost(value))) + } + _ => None, + }) + .collect(), + }), + regorus::Value::Null | regorus::Value::Undefined => Kind::NullValue(0), + _ => Kind::NullValue(0), + }), + } +} + fn parse_filesystem_policy(val: ®orus::Value) -> FilesystemPolicy { FilesystemPolicy { read_only: get_str_array(val, "read_only") @@ -735,6 +1010,14 @@ fn preprocess_yaml_data(yaml_str: &str) -> Result { } // Validate BEFORE expanding presets (catches user errors like rules+access) + let middleware_errors = validate_middleware_policies(&data); + if !middleware_errors.is_empty() { + return Err(miette::miette!( + "middleware policy validation failed:\n{}", + middleware_errors.join("\n") + )); + } + let (errors, warnings) = crate::l7::validate_l7_policies(&data); for w in &warnings { openshell_ocsf::ocsf_emit!( @@ -955,6 +1238,131 @@ fn normalize_l7_rule_aliases( } } +fn validate_middleware_policies(data: &serde_json::Value) -> Vec { + let mut errors = Vec::new(); + let middlewares = data + .get("network_middlewares") + .and_then(serde_json::Value::as_array) + .map_or(&[][..], Vec::as_slice); + let mut names = HashSet::new(); + for mw in middlewares { + let name = mw + .get("name") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + let implementation = mw + .get("middleware") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + if name.is_empty() { + errors.push("network_middlewares entry has empty name".to_string()); + } else if !names.insert(name.to_string()) { + errors.push(format!("duplicate middleware config '{name}'")); + } + if implementation.is_empty() { + errors.push(format!( + "middleware config '{name}' has empty implementation" + )); + } + if implementation.starts_with("openshell/") + && implementation != openshell_supervisor_middleware::BUILTIN_SECRETS + { + errors.push(format!( + "middleware config '{name}' references unsupported built-in '{implementation}'" + )); + } + let on_error = mw + .get("on_error") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + if !matches!(on_error, "" | "fail_closed" | "fail_open") { + errors.push(format!( + "middleware config '{name}' has invalid on_error '{on_error}'" + )); + } + } + + let Some(policies) = data + .get("network_policies") + .and_then(serde_json::Value::as_object) + else { + return errors; + }; + + for (policy_name, policy) in policies { + let policy_middleware = json_string_array(policy.get("middleware")); + for name in &policy_middleware { + if !names.contains(name) { + errors.push(format!( + "network policy '{policy_name}' references unknown middleware config '{name}'" + )); + } + } + for endpoint in policy + .get("endpoints") + .and_then(serde_json::Value::as_array) + .map_or(&[][..], Vec::as_slice) + { + let endpoint_middleware = json_string_array(endpoint.get("middleware")); + for name in &endpoint_middleware { + if !names.contains(name) { + errors.push(format!( + "network policy '{policy_name}' endpoint references unknown middleware config '{name}'" + )); + } + } + let tls_skip = endpoint + .get("tls") + .and_then(serde_json::Value::as_str) + .is_some_and(|tls| tls == "skip"); + if tls_skip && (!policy_middleware.is_empty() || !endpoint_middleware.is_empty()) { + errors.push(format!( + "network policy '{policy_name}' attaches middleware to a tls: skip endpoint" + )); + } + if tls_skip && global_selector_matches_any_middleware(middlewares, endpoint) { + errors.push(format!( + "network policy '{policy_name}' tls: skip endpoint matches a global middleware selector" + )); + } + } + } + errors +} + +fn json_string_array(value: Option<&serde_json::Value>) -> Vec { + value + .and_then(serde_json::Value::as_array) + .map(|values| { + values + .iter() + .filter_map(serde_json::Value::as_str) + .map(ToString::to_string) + .collect() + }) + .unwrap_or_default() +} + +fn global_selector_matches_any_middleware( + middlewares: &[serde_json::Value], + endpoint: &serde_json::Value, +) -> bool { + let host = endpoint + .get("host") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + middlewares.iter().any(|mw| { + let Some(selector) = mw.get("endpoints") else { + return false; + }; + let includes = json_string_array(selector.get("include")); + let excludes = json_string_array(selector.get("exclude")); + !includes.is_empty() + && includes.iter().any(|pattern| host_matches(pattern, host)) + && !excludes.iter().any(|pattern| host_matches(pattern, host)) + }) +} + /// Resolve a policy binary path through the container's root filesystem. /// /// On Linux, `/proc//root/` provides access to the container's mount @@ -1316,6 +1724,9 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St allow_all_known_mcp_methods.into(); } } + if !e.middleware.is_empty() { + ep["middleware"] = e.middleware.clone().into(); + } ep }) .collect(); @@ -1341,14 +1752,43 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St entries }) .collect(); - ( - key.clone(), - serde_json::json!({ - "name": rule.name, - "endpoints": endpoints, - "binaries": binaries, - }), - ) + let mut policy = serde_json::json!({ + "name": rule.name, + "endpoints": endpoints, + "binaries": binaries, + }); + if !rule.middleware.is_empty() { + policy["middleware"] = rule.middleware.clone().into(); + } + (key.clone(), policy) + }) + .collect(); + + let network_middlewares: Vec = proto + .network_middlewares + .iter() + .map(|mw| { + let mut value = serde_json::json!({ + "name": mw.name, + "middleware": mw.middleware, + }); + if let Some(config) = &mw.config { + value["config"] = prost_struct_to_json(config); + } + if !mw.on_error.is_empty() { + value["on_error"] = mw.on_error.clone().into(); + } + if let Some(selector) = &mw.endpoints { + let mut endpoints = serde_json::json!({}); + if !selector.include.is_empty() { + endpoints["include"] = selector.include.clone().into(); + } + if !selector.exclude.is_empty() { + endpoints["exclude"] = selector.exclude.clone().into(); + } + value["endpoints"] = endpoints; + } + value }) .collect(); @@ -1357,10 +1797,37 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St "landlock": landlock, "process": process, "network_policies": network_policies, + "network_middlewares": network_middlewares, }) .to_string() } +fn prost_struct_to_json(config: &prost_types::Struct) -> serde_json::Value { + serde_json::Value::Object( + config + .fields + .iter() + .map(|(key, value)| (key.clone(), prost_value_to_json(value))) + .collect(), + ) +} + +fn prost_value_to_json(value: &prost_types::Value) -> serde_json::Value { + match value.kind.as_ref() { + Some(prost_types::value::Kind::NullValue(_)) | None => serde_json::Value::Null, + Some(prost_types::value::Kind::BoolValue(value)) => serde_json::Value::Bool(*value), + Some(prost_types::value::Kind::NumberValue(value)) => serde_json::Number::from_f64(*value) + .map_or(serde_json::Value::Null, serde_json::Value::Number), + Some(prost_types::value::Kind::StringValue(value)) => { + serde_json::Value::String(value.clone()) + } + Some(prost_types::value::Kind::ListValue(value)) => { + serde_json::Value::Array(value.values.iter().map(prost_value_to_json).collect()) + } + Some(prost_types::value::Kind::StructValue(value)) => prost_struct_to_json(value), + } +} + #[cfg(test)] #[allow( clippy::needless_raw_string_hashes, @@ -1407,6 +1874,7 @@ mod tests { path: "/usr/local/bin/claude".to_string(), ..Default::default() }], + ..Default::default() }, ); network_policies.insert( @@ -1422,6 +1890,7 @@ mod tests { path: "/usr/bin/glab".to_string(), ..Default::default() }], + ..Default::default() }, ); ProtoSandboxPolicy { @@ -1439,6 +1908,7 @@ mod tests { run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], } } @@ -2763,6 +3233,7 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -2781,6 +3252,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3783,6 +4255,7 @@ network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -3800,6 +4273,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3840,6 +4314,7 @@ network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -3857,6 +4332,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3898,6 +4374,7 @@ network_policies: path: "/usr/local/bin/claude".to_string(), ..Default::default() }], + middleware: vec![], }, ); let proto = ProtoSandboxPolicy { @@ -3915,6 +4392,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3958,6 +4436,7 @@ network_policies: path: "/usr/local/bin/aws".to_string(), ..Default::default() }], + middleware: vec![], }, ); let proto = ProtoSandboxPolicy { @@ -3975,6 +4454,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -4017,6 +4497,7 @@ network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -4034,6 +4515,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -4966,6 +5448,7 @@ process: ..Default::default() }], binaries: vec![proposal_binary], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -4983,6 +5466,7 @@ process: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); let input = NetworkInput { @@ -5020,6 +5504,7 @@ process: path: "/usr/bin/python".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5037,6 +5522,7 @@ process: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); let input = NetworkInput { @@ -5090,6 +5576,7 @@ process: path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5107,6 +5594,7 @@ process: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); @@ -5320,6 +5808,7 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5337,6 +5826,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).unwrap(); // Port 443 @@ -6023,6 +6513,7 @@ network_policies: path: "/usr/bin/python3".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -6279,6 +6770,7 @@ network_policies: path: link_path, ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -6296,6 +6788,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; // Build engine with our PID (symlink resolution will work via /proc/self/root/) @@ -6356,6 +6849,7 @@ network_policies: path: link_path, ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -6373,6 +6867,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; // Initial load at pid=0 — no symlink expansion @@ -6415,6 +6910,133 @@ network_policies: assert!(eval_l7(&engine, &input)); } + #[test] + fn middleware_chain_orders_global_policy_endpoint_once() { + let data = r#" +network_middlewares: + - name: global-redactor + middleware: openshell/secrets + endpoints: + include: ["api.example.com"] + - name: policy-redactor + middleware: openshell/secrets + - name: endpoint-redactor + middleware: openshell/secrets +network_policies: + api: + name: api + middleware: ["global-redactor", "policy-redactor"] + endpoints: + - host: api.example.com + port: 443 + protocol: rest + enforcement: enforce + middleware: ["policy-redactor", "endpoint-redactor"] + rules: + - allow: { method: POST, path: "/v1/**" } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (chain, _) = engine + .query_middleware_chain_with_generation(&input, "/v1/messages") + .unwrap(); + let names: Vec<_> = chain.iter().map(|entry| entry.name.as_str()).collect(); + assert_eq!( + names, + vec!["global-redactor", "policy-redactor", "endpoint-redactor"] + ); + } + + #[test] + fn middleware_policy_validation_rejects_bad_configs() { + let cases = [ + ( + "missing reference", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets +network_policies: + api: + middleware: ["missing"] + endpoints: + - { host: api.example.com, port: 443 } + binaries: + - { path: /usr/bin/curl } +"#, + "unknown middleware config 'missing'", + ), + ( + "invalid on_error", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets + on_error: maybe +"#, + "invalid on_error", + ), + ( + "duplicate names", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets + - name: redactor + middleware: openshell/secrets +"#, + "duplicate middleware config 'redactor'", + ), + ( + "reserved builtin", + r#" +network_middlewares: + - name: sigv4 + middleware: openshell/sigv4 +"#, + "unsupported built-in", + ), + ( + "tls skip attachment", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets +network_policies: + api: + endpoints: + - host: api.example.com + port: 443 + tls: skip + middleware: ["redactor"] + binaries: + - { path: /usr/bin/curl } +"#, + "tls: skip", + ), + ]; + + for (name, data, expected) in cases { + let err = match OpaEngine::from_strings(TEST_POLICY, data) { + Ok(_) => panic!("{name}: expected policy validation failure"), + Err(err) => err.to_string(), + }; + assert!( + err.contains(expected), + "{name}: expected {expected:?} in {err:?}" + ); + } + } + #[test] fn l7_head_denied_when_only_post_allowed() { let engine = OpaEngine::from_strings( diff --git a/crates/openshell-supervisor-network/src/policy_local.rs b/crates/openshell-supervisor-network/src/policy_local.rs index 3cbc31502..fa8029c72 100644 --- a/crates/openshell-supervisor-network/src/policy_local.rs +++ b/crates/openshell-supervisor-network/src/policy_local.rs @@ -1047,6 +1047,7 @@ fn network_rule_from_json( name: rule.name.unwrap_or_default(), endpoints, binaries, + middleware: Vec::new(), }) } @@ -1133,6 +1134,7 @@ fn network_endpoint_from_json( credential_signing: String::new(), signing_service: String::new(), signing_region: String::new(), + middleware: Vec::new(), }) } @@ -1829,6 +1831,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }), ..Default::default() }; @@ -1853,6 +1856,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() } } @@ -1916,6 +1920,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() })); }) }; diff --git a/crates/openshell-supervisor-network/src/proxy.rs b/crates/openshell-supervisor-network/src/proxy.rs index 0d2c8c025..af3331735 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -1183,6 +1183,7 @@ async fn handle_tcp_connection( &mut tls_upstream, &ctx, &generation_guard, + Some(&opa_engine), ) .await } @@ -1288,6 +1289,7 @@ async fn handle_tcp_connection( &mut upstream, &ctx, &generation_guard, + Some(&opa_engine), ) .await { diff --git a/proto/middleware.proto b/proto/middleware.proto new file mode 100644 index 000000000..d5d2ad48d --- /dev/null +++ b/proto/middleware.proto @@ -0,0 +1,95 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package openshell.middleware.v1; + +import "google/protobuf/empty.proto"; +import "google/protobuf/struct.proto"; + +service SupervisorMiddleware { + rpc Describe(google.protobuf.Empty) returns (MiddlewareManifest); + rpc ValidateConfig(ValidateConfigRequest) returns (ValidateConfigResponse); + rpc EvaluateHttpRequest(HttpRequestEvaluation) returns (HttpRequestResult); +} + +message MiddlewareManifest { + string api_version = 1; + string name = 2; + string service_version = 3; + repeated MiddlewareBinding bindings = 4; +} + +message MiddlewareBinding { + string id = 1; + string operation = 2; + string phase = 3; +} + +message ValidateConfigRequest { + string api_version = 1; + string binding_id = 2; + google.protobuf.Struct config = 3; +} + +message ValidateConfigResponse { + bool valid = 1; + string reason = 2; +} + +message HttpRequestEvaluation { + string api_version = 1; + string binding_id = 2; + string phase = 3; + RequestContext context = 4; + google.protobuf.Struct config = 5; + HttpRequestTarget target = 6; + map headers = 7; + bytes body = 8; +} + +message RequestContext { + string request_id = 1; + string sandbox_id = 2; + Process originating_process = 3; +} + +message HttpRequestTarget { + string scheme = 1; + string host = 2; + uint32 port = 3; + string method = 4; + string path = 5; + string query = 6; +} + +message Process { + string binary = 1; + uint32 pid = 2; + repeated string ancestors = 3; +} + +enum Decision { + DECISION_UNSPECIFIED = 0; + DECISION_ALLOW = 1; + DECISION_DENY = 2; +} + +message Finding { + string type = 1; + string label = 2; + uint32 count = 3; + string confidence = 4; + string severity = 5; +} + +message HttpRequestResult { + Decision decision = 1; + string reason = 2; + bytes body = 3; + bool has_body = 4; + map add_headers = 5; + repeated Finding findings = 6; + map metadata = 7; +} diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 8a5a59333..5d2bc31a5 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -5,6 +5,8 @@ syntax = "proto3"; package openshell.sandbox.v1; +import "google/protobuf/struct.proto"; + // Sandbox-supervisor configuration and policy messages. // // Conventions: @@ -25,6 +27,8 @@ message SandboxPolicy { ProcessPolicy process = 4; // Network access policies keyed by name (e.g. "claude_code", "gitlab"). map network_policies = 5; + // Reusable supervisor middleware configs for network egress. + repeated NetworkMiddlewareConfig network_middlewares = 6; } // Filesystem access policy. @@ -59,6 +63,27 @@ message NetworkPolicyRule { repeated NetworkEndpoint endpoints = 2; // Allowed binary identities. repeated NetworkBinary binaries = 3; + // Ordered middleware configs applied to every endpoint in this policy. + repeated string middleware = 4; +} + +// A reusable middleware config referenced by network policies/endpoints. +message NetworkMiddlewareConfig { + // Policy-local config name. + string name = 1; + // Built-in or registered middleware implementation name. + string middleware = 2; + // Service-specific configuration. + google.protobuf.Struct config = 3; + // Failure behavior: "fail_closed" (default) or "fail_open". + string on_error = 4; + // Optional global endpoint selector for this config. + MiddlewareEndpointSelector endpoints = 5; +} + +message MiddlewareEndpointSelector { + repeated string include = 1; + repeated string exclude = 2; } // A network endpoint (host + port) with optional L7 inspection config. @@ -143,6 +168,8 @@ message NetworkEndpoint { uint32 json_rpc_max_body_bytes = 22; // MCP-only policy and inspection options. Only used when protocol is "mcp". McpOptions mcp = 23; + // Ordered middleware configs applied to this endpoint after policy-level middleware. + repeated string middleware = 24; } // MCP options are grouped so MCP-specific policy can grow without adding more @@ -175,8 +202,6 @@ message McpOptions { // MCP-family methods at the method layer unless a tool-name policy narrows // tools/call. When unset or false, explicit method rules are required. optional bool allow_all_known_mcp_methods = 2; -} - // Trusted GraphQL operation classification. message GraphqlOperation { // Operation type: "query", "mutation", or "subscription". From e4e6f8fefc23a2d1745b272ac730e97c179e77be Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 26 Jun 2026 13:25:39 -0700 Subject: [PATCH 02/27] fix(supervisor-middleware): harden middleware relay handling Signed-off-by: Piotr Mlocek --- Cargo.lock | 1 + crates/openshell-cli/src/policy_update.rs | 1 + .../src/mechanistic_mapper.rs | 1 + crates/openshell-server/src/grpc/policy.rs | 45 +- .../src/builtins/mod.rs | 2 +- .../src/builtins/secrets.rs | 87 +++- .../src/lib.rs | 298 ++++++++++- .../openshell-supervisor-network/Cargo.toml | 1 + .../src/l7/relay.rs | 489 +++++++++++++++++- .../src/l7/rest.rs | 283 ++++++---- .../openshell-supervisor-network/src/opa.rs | 23 +- 11 files changed, 1059 insertions(+), 172 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6f664acba..3ed95de90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3982,6 +3982,7 @@ dependencies = [ "tokio-tungstenite 0.26.2", "tower-mcp-types", "tracing", + "tracing-subscriber", "uuid", "webpki-roots 1.0.7", ] diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 1f1f64750..824b1dde0 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -65,6 +65,7 @@ pub fn build_policy_update_plan( ..Default::default() }) .collect(), + middleware: Vec::new(), }; merge_operations.push(PolicyMergeOperation { operation: Some(policy_merge_operation::Operation::AddRule(AddNetworkRule { diff --git a/crates/openshell-sandbox/src/mechanistic_mapper.rs b/crates/openshell-sandbox/src/mechanistic_mapper.rs index 8ee2fc37f..bb83ddb66 100644 --- a/crates/openshell-sandbox/src/mechanistic_mapper.rs +++ b/crates/openshell-sandbox/src/mechanistic_mapper.rs @@ -162,6 +162,7 @@ pub fn generate_proposals(summaries: &[DenialSummary]) -> Vec { name: rule_name.clone(), endpoints: vec![endpoint], binaries, + middleware: Vec::new(), }; // Compute confidence. diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index cc8ff0d2e..09e311bb2 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -5746,6 +5746,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -5959,6 +5960,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -6075,6 +6077,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6180,6 +6183,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let mechanistic_submit = handle_submit_policy_analysis( &state, @@ -6257,6 +6261,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let agent_submit = handle_submit_policy_analysis( &state, @@ -6384,6 +6389,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6484,6 +6490,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6584,6 +6591,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6677,6 +6685,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6761,6 +6770,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6849,6 +6859,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6940,6 +6951,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -7026,6 +7038,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let response = handle_submit_policy_analysis( @@ -7201,6 +7214,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -7297,6 +7311,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -7382,6 +7397,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -7523,6 +7539,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -7648,6 +7665,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let step1 = handle_submit_policy_analysis( &state, @@ -7689,6 +7707,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let step2 = handle_submit_policy_analysis( &state, @@ -7820,6 +7839,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let submit_one = |rule_name: &str, rule: NetworkPolicyRule| { @@ -7928,6 +7948,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let submit_one = || { let state = state.clone(); @@ -8028,6 +8049,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -8159,6 +8181,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -8357,6 +8380,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, }; @@ -8385,6 +8409,7 @@ mod tests { path: "/usr/bin/node".to_string(), ..Default::default() }], + ..Default::default() }, }; @@ -8413,6 +8438,7 @@ mod tests { path: "/usr/bin/node".to_string(), ..Default::default() }], + ..Default::default() }, }; @@ -8440,6 +8466,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-1".to_string(), @@ -8508,6 +8535,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, )) .collect(), @@ -8536,6 +8564,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-merge".to_string(), @@ -8609,6 +8638,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, )) .collect(), @@ -8637,6 +8667,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-new".to_string(), @@ -8773,7 +8804,7 @@ mod tests { allowed_ips: vec!["127.0.0.1".to_string()], ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8794,7 +8825,7 @@ mod tests { allowed_ips: vec!["169.254.169.254".to_string()], ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8812,7 +8843,7 @@ mod tests { port: 80, ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8830,7 +8861,7 @@ mod tests { port: 8080, ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8848,7 +8879,7 @@ mod tests { port: 80, ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8896,7 +8927,7 @@ mod tests { allowed_ips: vec!["10.0.5.0/24".to_string()], ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_ok()); @@ -8913,7 +8944,7 @@ mod tests { port: 443, ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_ok()); diff --git a/crates/openshell-supervisor-middleware/src/builtins/mod.rs b/crates/openshell-supervisor-middleware/src/builtins/mod.rs index 60572d3e8..d91ee745e 100644 --- a/crates/openshell-supervisor-middleware/src/builtins/mod.rs +++ b/crates/openshell-supervisor-middleware/src/builtins/mod.rs @@ -1,4 +1,4 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -pub(crate) mod secrets; +pub mod secrets; diff --git a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs index 6c94eb439..572102559 100644 --- a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs +++ b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use std::collections::HashMap; +use std::sync::LazyLock; use miette::{Result, miette}; use openshell_core::proto::{Decision, Finding, HttpRequestEvaluation, HttpRequestResult}; @@ -9,7 +10,36 @@ use regex::Regex; use crate::BUILTIN_SECRETS; -pub(crate) fn validate_config(config: &prost_types::Struct) -> Result<()> { +/// A named secret-detection pattern. The `kind` is an audit-safe label that +/// flows into findings so operators can see *what* matched without seeing the +/// raw value. +struct SecretPattern { + kind: &'static str, + regex: Regex, +} + +impl SecretPattern { + fn new(kind: &'static str, pattern: &str) -> Self { + Self { + kind, + regex: Regex::new(pattern).expect("valid built-in secret redaction pattern"), + } + } +} + +/// Compiled once: recompiling per request would put regex construction on the +/// egress hot path. +static SECRET_PATTERNS: LazyLock<[SecretPattern; 2]> = LazyLock::new(|| { + [ + SecretPattern::new( + "keyword", + r#"(?i)(api[_-]?key|access[_-]?token|secret|password)(["']?\s*[:=]\s*["'])[^"',\s}]+(["']?)"#, + ), + SecretPattern::new("openai", r"(sk-[A-Za-z0-9_-]{16,})"), + ] +}); + +pub fn validate_config(config: &prost_types::Struct) -> Result<()> { let mode = config .fields .get("secrets") @@ -27,49 +57,54 @@ pub(crate) fn validate_config(config: &prost_types::Struct) -> Result<()> { Ok(()) } -pub(crate) fn evaluate_http_request( - evaluation: &HttpRequestEvaluation, -) -> Result { +pub fn evaluate_http_request(evaluation: &HttpRequestEvaluation) -> Result { let default_config = prost_types::Struct::default(); validate_config(evaluation.config.as_ref().unwrap_or(&default_config))?; let text = String::from_utf8(evaluation.body.clone()) .map_err(|_| miette!("{} requires UTF-8 request bodies", BUILTIN_SECRETS))?; - let (body, count) = redact_common_secrets(&text)?; + let (body, matches) = redact_common_secrets(&text); + let total: u32 = matches + .iter() + .fold(0u32, |acc, (_, count)| acc.saturating_add(*count)); let mut result = HttpRequestResult { decision: Decision::Allow as i32, reason: String::new(), body: body.into_bytes(), - has_body: count > 0, + has_body: !matches.is_empty(), add_headers: HashMap::new(), findings: Vec::new(), metadata: HashMap::new(), }; - if count > 0 { - result.findings.push(Finding { - r#type: "secret.common".into(), - label: "common secret pattern".into(), - count, - confidence: "medium".into(), - severity: "medium".into(), - }); + if !matches.is_empty() { + // One finding per matched pattern kind, so audit shows what matched. + for (kind, count) in &matches { + result.findings.push(Finding { + r#type: format!("secret.{kind}"), + label: format!("{kind} secret pattern"), + count: *count, + confidence: "medium".into(), + severity: "medium".into(), + }); + } result .metadata - .insert("secrets_redacted".into(), count.to_string()); + .insert("secrets_redacted".into(), total.to_string()); } Ok(result) } -fn redact_common_secrets(input: &str) -> Result<(String, u32)> { - let patterns = [ - r#"(?i)(api[_-]?key|access[_-]?token|secret|password)(["']?\s*[:=]\s*["'])[^"',\s}]+(["']?)"#, - r#"(sk-[A-Za-z0-9_-]{16,})"#, - ]; +/// Redact every configured secret pattern, returning the transformed text and +/// the per-kind match counts (only kinds that matched are included). +fn redact_common_secrets(input: &str) -> (String, Vec<(&'static str, u32)>) { let mut output = input.to_string(); - let mut count = 0u32; - for pattern in patterns { - let regex = Regex::new(pattern).map_err(|e| miette!("{e}"))?; - count = count.saturating_add(regex.find_iter(&output).count() as u32); - output = regex + let mut matches = Vec::new(); + for pattern in SECRET_PATTERNS.iter() { + let count = u32::try_from(pattern.regex.find_iter(&output).count()).unwrap_or(u32::MAX); + if count > 0 { + matches.push((pattern.kind, count)); + } + output = pattern + .regex .replace_all(&output, |captures: ®ex::Captures<'_>| { if captures.len() >= 4 { format!("{}{}[REDACTED]{}", &captures[1], &captures[2], &captures[3]) @@ -79,5 +114,5 @@ fn redact_common_secrets(input: &str) -> Result<(String, u32)> { }) .into_owned(); } - Ok((output, count)) + (output, matches) } diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index 7d9161fcf..b68d83c86 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -100,7 +100,7 @@ pub struct ChainOutcome { pub applied: Vec, } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct NamespacedFinding { pub middleware: String, pub finding: Finding, @@ -112,6 +112,48 @@ pub struct MiddlewareInvocation { pub implementation: String, pub decision: Decision, pub transformed: bool, + /// True when the middleware could not be evaluated and `on_error` was applied + /// (service error, malformed/unsafe response, etc.). The `decision` reflects + /// the `on_error` outcome, not a decision the middleware actually returned. + pub failed: bool, +} + +enum OnErrorAction { + /// `fail_open`: skip this middleware, leaving the request unchanged. + FailOpen, + /// `fail_closed`: short-circuit the chain and deny with the given reason. + FailClosed(String), +} + +/// Apply a middleware entry's `on_error` policy after a failure (service error or +/// malformed response). Records a `failed` invocation for telemetry in both cases. +fn apply_on_error( + entry: &ChainEntry, + reason: &str, + applied: &mut Vec, +) -> OnErrorAction { + match entry.on_error { + OnError::FailOpen => { + applied.push(MiddlewareInvocation { + name: entry.name.clone(), + implementation: entry.implementation.clone(), + decision: Decision::Allow, + transformed: false, + failed: true, + }); + OnErrorAction::FailOpen + } + OnError::FailClosed => { + applied.push(MiddlewareInvocation { + name: entry.name.clone(), + implementation: entry.implementation.clone(), + decision: Decision::Deny, + transformed: false, + failed: true, + }); + OnErrorAction::FailClosed(format!("middleware_failed: {reason}")) + } + } } #[derive(Clone)] @@ -150,20 +192,33 @@ impl ChainRunner { .await { Ok(result) => result.into_inner(), - Err(err) => match entry.on_error { - OnError::FailOpen => { - applied.push(MiddlewareInvocation { - name: entry.name.clone(), - implementation: entry.implementation.clone(), - decision: Decision::Allow, - transformed: false, - }); - continue; + Err(err) => { + match apply_on_error(entry, &safe_reason(&err.to_string()), &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } } - OnError::FailClosed => { + } + }; + + // A result proposing unsafe header mutations is a malformed response: + // route it through `on_error` instead of applying any of it. + if validate_header_mutations(&headers, &result.add_headers).is_err() { + match apply_on_error(entry, "unsafe_response_headers", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { return Ok(ChainOutcome { allowed: false, - reason: format!("middleware_failed: {}", safe_reason(&err.to_string())), + reason, body, added_headers, findings, @@ -171,17 +226,15 @@ impl ChainRunner { applied, }); } - }, - }; - - validate_header_mutations(&headers, &result.add_headers)?; + } + } for (name, value) in &result.add_headers { headers.insert(name.to_ascii_lowercase(), value.clone()); added_headers.insert(name.to_ascii_lowercase(), value.clone()); } let transformed = result.has_body; if result.has_body { - body = result.body.clone(); + result.body.clone_into(&mut body); } for finding in result.findings { findings.push(NamespacedFinding { @@ -200,6 +253,7 @@ impl ChainRunner { implementation: entry.implementation.clone(), decision: Decision::try_from(result.decision).unwrap_or(Decision::Unspecified), transformed, + failed: false, }); if result.decision == Decision::Deny as i32 { return Ok(ChainOutcome { @@ -264,7 +318,7 @@ fn validate_header_mutations( mutations: &HashMap, ) -> Result<()> { let mut seen = HashSet::new(); - for name in mutations.keys() { + for (name, value) in mutations { let lower = name.to_ascii_lowercase(); if !seen.insert(lower.clone()) || existing_headers.contains_key(&lower) { return Err(miette!( @@ -274,10 +328,27 @@ fn validate_header_mutations( if !is_safe_append_header(&lower) { return Err(miette!("middleware cannot append unsafe header '{name}'")); } + // Reject CR/LF and other control characters in the value: writing them + // verbatim into the upstream header block would enable header injection + // and request smuggling past the credential boundary. + if !is_safe_header_value(value) { + return Err(miette!( + "middleware cannot append header '{name}' with an unsafe value" + )); + } } Ok(()) } +/// A header value is safe to append only if it contains no control characters. +/// Horizontal tab, printable ASCII, and obs-text (>= 0x80) are permitted; CR, LF, +/// NUL, and other control bytes are rejected. +fn is_safe_header_value(value: &str) -> bool { + value + .bytes() + .all(|b| b == b'\t' || (0x20..=0x7e).contains(&b) || b >= 0x80) +} + fn is_safe_append_header(name: &str) -> bool { if name.is_empty() || name.contains(':') @@ -312,13 +383,12 @@ mod tests { name: name.into(), implementation: BUILTIN_SECRETS.into(), config: prost_types::Struct { - fields: [( + fields: std::iter::once(( "secrets".into(), prost_types::Value { kind: Some(prost_types::value::Kind::StringValue("redact".into())), }, - )] - .into_iter() + )) .collect(), }, on_error, @@ -427,11 +497,191 @@ mod tests { fn unsafe_header_mutation_is_rejected() { let err = validate_header_mutations( &BTreeMap::new(), - &[("Authorization".into(), "Bearer nope".into())] - .into_iter() - .collect(), + &std::iter::once(("Authorization".into(), "Bearer nope".into())).collect(), ) .expect_err("unsafe header"); assert!(err.to_string().contains("unsafe header")); } + + #[test] + fn header_value_with_crlf_is_rejected() { + // A safe header *name* with a CRLF-bearing value must still be rejected, + // otherwise it would inject extra headers into the upstream request. + let err = validate_header_mutations( + &BTreeMap::new(), + &std::iter::once(( + "x-openshell-middleware-inject".into(), + "ok\r\nAuthorization: Bearer evil".into(), + )) + .collect(), + ) + .expect_err("crlf value"); + assert!(err.to_string().contains("unsafe value")); + } + + /// A mock middleware that returns a fixed, caller-supplied result for every + /// evaluation. Used to exercise chain behavior the built-in cannot produce + /// (explicit deny, metadata, findings, unsafe header mutations). + struct ScriptedService { + result: openshell_core::proto::HttpRequestResult, + } + + #[tonic::async_trait] + impl SupervisorMiddleware for ScriptedService { + async fn describe( + &self, + _request: Request<()>, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new( + openshell_core::proto::MiddlewareManifest::default(), + )) + } + + async fn validate_config( + &self, + _request: Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new( + openshell_core::proto::ValidateConfigResponse { + valid: true, + reason: String::new(), + }, + )) + } + + async fn evaluate_http_request( + &self, + _request: Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new(self.result.clone())) + } + } + + fn allow_result() -> openshell_core::proto::HttpRequestResult { + openshell_core::proto::HttpRequestResult { + decision: Decision::Allow as i32, + reason: String::new(), + body: Vec::new(), + has_body: false, + add_headers: HashMap::new(), + findings: Vec::new(), + metadata: HashMap::new(), + } + } + + #[tokio::test] + async fn deny_decision_short_circuits_chain() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + result: openshell_core::proto::HttpRequestResult { + decision: Decision::Deny as i32, + reason: "blocked_by_policy".into(), + ..allow_result() + }, + })); + let outcome = runner + .evaluate( + &[ + entry("first", OnError::FailClosed), + entry("second", OnError::FailClosed), + ], + input("hello"), + ) + .await + .expect("evaluate"); + assert!(!outcome.allowed); + assert_eq!(outcome.reason, "blocked_by_policy"); + // The deny short-circuits the chain: the second middleware never runs. + assert_eq!(outcome.applied.len(), 1); + assert_eq!(outcome.applied[0].decision, Decision::Deny); + assert!(!outcome.applied[0].failed); + } + + #[tokio::test] + async fn metadata_and_findings_are_namespaced_per_config() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + result: openshell_core::proto::HttpRequestResult { + findings: vec![Finding { + r#type: "pii.email".into(), + label: "email address".into(), + count: 2, + confidence: "high".into(), + severity: "medium".into(), + }], + metadata: std::iter::once(("sensitivity".to_string(), "high".to_string())) + .collect(), + ..allow_result() + }, + })); + let outcome = runner + .evaluate( + &[ + entry("alpha", OnError::FailClosed), + entry("beta", OnError::FailClosed), + ], + input("hello"), + ) + .await + .expect("evaluate"); + assert!(outcome.allowed); + // Metadata is bucketed under each config's local name, so two configs + // emitting the same key do not collide. + assert_eq!(outcome.metadata["alpha"]["sensitivity"], "high"); + assert_eq!(outcome.metadata["beta"]["sensitivity"], "high"); + // Findings are tagged with the emitting config's name. + assert_eq!(outcome.findings.len(), 2); + assert_eq!(outcome.findings[0].middleware, "alpha"); + assert_eq!(outcome.findings[1].middleware, "beta"); + assert_eq!(outcome.findings[0].finding.r#type, "pii.email"); + assert_eq!(outcome.findings[0].finding.count, 2); + } + + fn unsafe_header_service() -> ScriptedService { + ScriptedService { + result: openshell_core::proto::HttpRequestResult { + add_headers: std::iter::once(( + "x-openshell-middleware-inject".to_string(), + "ok\r\nHost: evil".to_string(), + )) + .collect(), + ..allow_result() + }, + } + } + + #[tokio::test] + async fn malformed_response_headers_fail_closed_denies() { + let runner = ChainRunner::new(Arc::new(unsafe_header_service())); + let outcome = runner + .evaluate(&[entry("redact", OnError::FailClosed)], input("hello")) + .await + .expect("evaluate"); + assert!(!outcome.allowed); + assert!(outcome.reason.starts_with("middleware_failed:")); + assert!(outcome.applied.iter().any(|inv| inv.failed)); + // The unsafe header is never forwarded. + assert!(outcome.added_headers.is_empty()); + } + + #[tokio::test] + async fn malformed_response_headers_fail_open_continues() { + let runner = ChainRunner::new(Arc::new(unsafe_header_service())); + let outcome = runner + .evaluate(&[entry("redact", OnError::FailOpen)], input("hello")) + .await + .expect("evaluate"); + assert!(outcome.allowed); + assert_eq!(outcome.body, b"hello"); + assert!(outcome.added_headers.is_empty()); + assert_eq!(outcome.applied.len(), 1); + assert!(outcome.applied[0].failed); + } } diff --git a/crates/openshell-supervisor-network/Cargo.toml b/crates/openshell-supervisor-network/Cargo.toml index fd8fad5f7..b8cae5113 100644 --- a/crates/openshell-supervisor-network/Cargo.toml +++ b/crates/openshell-supervisor-network/Cargo.toml @@ -55,6 +55,7 @@ tempfile = "3" temp-env = "0.3" tokio-tungstenite = { workspace = true } futures = { workspace = true } +tracing-subscriber = { workspace = true } [target.'cfg(unix)'.dev-dependencies] libc = "0.2" diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 4d501d0a3..6e5c3c4e9 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -783,9 +783,18 @@ async fn apply_middleware_chain( if chain.is_empty() { return Ok(MiddlewareApplyResult::Allowed(req)); } - let buffered = - crate::l7::rest::buffer_request_body_for_middleware(&req, client, Some(generation_guard)) - .await?; + let buffered = match crate::l7::rest::buffer_request_body_for_middleware( + &req, + client, + Some(generation_guard), + ) + .await? + { + crate::l7::rest::BufferResult::Buffered(buffered) => buffered, + crate::l7::rest::BufferResult::OverCapacity { recoverable } => { + return Ok(resolve_unbuffered_body(ctx, req, &chain, recoverable)); + } + }; let headers = safe_middleware_headers(&buffered.headers)?; let input = openshell_supervisor_middleware::HttpRequestInput { request_id: uuid::Uuid::new_v4().to_string(), @@ -819,6 +828,52 @@ async fn apply_middleware_chain( } } +/// Apply the chain's `on_error` policy when the request body cannot be buffered +/// for inspection because it exceeds the size cap. The RFC treats an unbufferable +/// body as an `on_error` event: it is denied unless every attached middleware is +/// `fail_open`, and passing it through is only safe when no bytes were consumed. +fn resolve_unbuffered_body( + ctx: &L7EvalContext, + req: crate::l7::provider::L7Request, + chain: &[openshell_supervisor_middleware::ChainEntry], + recoverable: bool, +) -> MiddlewareApplyResult { + let all_fail_open = chain + .iter() + .all(|entry| entry.on_error == openshell_supervisor_middleware::OnError::FailOpen); + if recoverable && all_fail_open { + emit_middleware_body_unavailable(ctx, false); + return MiddlewareApplyResult::Allowed(req); + } + emit_middleware_body_unavailable(ctx, true); + MiddlewareApplyResult::Denied("middleware_failed: request_body_over_capacity".into()) +} + +fn emit_middleware_body_unavailable(ctx: &L7EvalContext, denied: bool) { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(if denied { + SeverityId::High + } else { + SeverityId::Medium + }) + .finding_info(FindingInfo::new( + "openshell.middleware.body_unavailable", + "Supervisor middleware could not inspect request body", + )) + .evidence_pairs(&[ + ("policy", ctx.policy_name.as_str()), + ("host", ctx.host.as_str()), + ("disposition", if denied { "denied" } else { "fail_open" }), + ]) + .message(if denied { + "Request body exceeded middleware inspection cap; denied" + } else { + "Request body exceeded middleware inspection cap; passed through (fail_open)" + }) + .build(); + ocsf_emit!(event); +} + fn safe_middleware_headers(headers: &[u8]) -> Result> { let header_str = std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; @@ -885,14 +940,37 @@ fn emit_middleware_events( .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) .firewall_rule(&ctx.policy_name, "middleware") .message(format!( - "MIDDLEWARE {} {} decision={:?} transformed={}", + "MIDDLEWARE {} {} decision={:?} transformed={} failed={}", invocation.name, invocation.implementation, invocation.decision, - invocation.transformed + invocation.transformed, + invocation.failed )) .build(); ocsf_emit!(event); + + // A middleware that failed but was bypassed under `fail_open` is an + // enforcement failure operators must be able to alert on, even though the + // request proceeded. + if invocation.failed && allowed { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(SeverityId::Medium) + .finding_info(FindingInfo::new( + "openshell.middleware.failure", + "Supervisor middleware failed open", + )) + .evidence_pairs(&[ + ("middleware", invocation.name.as_str()), + ("implementation", invocation.implementation.as_str()), + ]) + .message(format!( + "Middleware {} failed and was bypassed (fail_open)", + invocation.name + )) + .build(); + ocsf_emit!(event); + } } if !outcome.allowed && outcome.reason.starts_with("middleware_failed:") { let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) @@ -2658,6 +2736,407 @@ network_policies: .unwrap(); } + #[tokio::test] + async fn l7_rest_middleware_over_capacity_fails_closed() { + let (config, tunnel_engine, ctx) = + middleware_relay_context("openshell/secrets", "fail_closed"); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + // A declared body far above the 256 KiB inspection cap must be denied + // (fail-closed) before the body is read or reaches the upstream. + let request = format!( + "POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + 300 * 1024 + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("denial should reach client") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("request_body_over_capacity")); + + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_request), + ) + .await; + assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive request bytes" + ); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + + #[test] + fn over_capacity_resolution_honors_on_error() { + use openshell_supervisor_middleware::{ChainEntry, OnError}; + + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 443, + policy_name: "p".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let req = || crate::l7::provider::L7Request { + action: "POST".into(), + target: "/v1".into(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }; + let fail_open = ChainEntry { + name: "m".into(), + implementation: "openshell/secrets".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailOpen, + }; + let fail_closed = ChainEntry { + on_error: OnError::FailClosed, + ..fail_open.clone() + }; + + // Recoverable (Content-Length over cap, nothing consumed) + all fail-open + // -> stream through unprocessed. + assert!(matches!( + resolve_unbuffered_body(&ctx, req(), std::slice::from_ref(&fail_open), true), + MiddlewareApplyResult::Allowed(_) + )); + // Any fail-closed entry -> deny. + assert!(matches!( + resolve_unbuffered_body(&ctx, req(), &[fail_open.clone(), fail_closed], true), + MiddlewareApplyResult::Denied(_) + )); + // Not recoverable (chunked overflow already consumed bytes) -> deny even + // when every entry is fail-open. + assert!(matches!( + resolve_unbuffered_body(&ctx, req(), &[fail_open], false), + MiddlewareApplyResult::Denied(_) + )); + } + + /// Tracing layer that captures emitted `OcsfEvent`s for assertions. + struct OcsfCaptureLayer(Arc>>); + + impl tracing_subscriber::Layer for OcsfCaptureLayer { + fn on_event( + &self, + event: &tracing::Event<'_>, + _ctx: tracing_subscriber::layer::Context<'_, S>, + ) { + if event.metadata().target() == openshell_ocsf::OCSF_TARGET + && let Some(ocsf_event) = openshell_ocsf::clone_current_event() + { + self.0.lock().unwrap().push(ocsf_event); + } + } + } + + #[test] + fn middleware_ocsf_events_are_audit_safe() { + use openshell_supervisor_middleware::{ + ChainOutcome, MiddlewareInvocation, NamespacedFinding, + }; + use tracing_subscriber::layer::SubscriberExt; + + const RAW_SECRET: &str = "sk-RAWSECRETVALUE0123456789"; + + let events = Arc::new(std::sync::Mutex::new(Vec::new())); + let subscriber = tracing_subscriber::registry().with(OcsfCaptureLayer(Arc::clone(&events))); + let _guard = tracing::subscriber::set_default(subscriber); + + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 443, + policy_name: "rest_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let req = crate::l7::provider::L7Request { + action: "POST".into(), + target: "/v1/messages".into(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }; + let outcome = ChainOutcome { + allowed: true, + reason: String::new(), + // The transformed body still holds the raw secret; emission must never + // serialize it. + body: format!(r#"{{"api_key":"{RAW_SECRET}"}}"#).into_bytes(), + added_headers: BTreeMap::new(), + findings: vec![NamespacedFinding { + middleware: "redact-secrets".into(), + finding: openshell_core::proto::Finding { + r#type: "secret.common".into(), + label: "common secret pattern".into(), + count: 1, + confidence: "medium".into(), + severity: "medium".into(), + }, + }], + metadata: BTreeMap::new(), + applied: vec![MiddlewareInvocation { + name: "redact-secrets".into(), + implementation: "openshell/secrets".into(), + decision: openshell_core::proto::Decision::Allow, + transformed: true, + failed: false, + }], + }; + + emit_middleware_events(&ctx, &req, &outcome); + + let captured = events.lock().unwrap(); + // Per-invocation decisions are HTTP Activity (class 4002). + assert!( + captured.iter().any(|e| e.class_uid() == 4002), + "expected an HTTP Activity event for the middleware invocation" + ); + // Findings are Detection Finding (class 2004) with the finding's severity. + let finding_event = captured + .iter() + .find(|e| e.class_uid() == 2004) + .expect("expected a Detection Finding event"); + assert_eq!(finding_event.base().severity, SeverityId::Medium); + + // No raw payload material may appear in any emitted event. + let serialized = serde_json::to_string(&*captured).expect("serialize events"); + assert!( + !serialized.contains(RAW_SECRET), + "raw secret leaked into OCSF events: {serialized}" + ); + // Safe finding metadata is still present. + assert!(serialized.contains("secret.common")); + } + + #[tokio::test] + async fn passthrough_relay_runs_middleware_redaction() { + // A no-protocol endpoint takes the credential-injection passthrough path; + // policy-level middleware must still inspect and redact its body. + let data = r#" +network_middlewares: + - name: request-middleware + middleware: openshell/secrets + on_error: fail_closed +network_policies: + passthrough_api: + name: passthrough_api + middleware: ["request-middleware"] + endpoints: + - host: api.example.test + port: 8080 + binaries: + - { path: /usr/bin/curl } +"#; + let engine = Arc::new(OpaEngine::from_strings(TEST_POLICY, data).unwrap()); + let generation_guard = engine + .generation_guard(engine.current_generation()) + .unwrap(); + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 8080, + policy_name: "passthrough_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let engine_task = Arc::clone(&engine); + let relay = tokio::spawn(async move { + relay_passthrough_with_credentials( + &mut relay_client, + &mut relay_upstream, + &ctx, + &generation_guard, + Some(engine_task.as_ref()), + ) + .await + }); + + let body = br#"{"api_key":"sk-1234567890abcdef"}"#; + let request = format!( + "POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + std::str::from_utf8(body).unwrap() + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut upstream_request = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("request should reach upstream") + .unwrap(); + let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); + assert!( + upstream_request.contains(r#""api_key":"[REDACTED]""#), + "unexpected upstream request: {upstream_request:?}" + ); + assert!(!upstream_request.contains("sk-1234567890abcdef")); + + upstream + .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn websocket_upgrade_request_is_inspected_and_denied() { + // The WebSocket upgrade handshake is an HTTP request the hook can inspect + // and deny: a fail-closed middleware blocks the upgrade before it is + // forwarded. + let data = r#" +network_middlewares: + - name: request-middleware + middleware: example/unavailable + on_error: fail_closed +network_policies: + ws_api: + name: ws_api + middleware: ["request-middleware"] + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "gateway.example.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "ws_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("denial should reach client") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("middleware_failed")); + + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_request), + ) + .await; + assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive the upgrade request" + ); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + #[tokio::test] async fn passthrough_relay_injects_token_grant_authorization_header() { let (generation_guard, ctx, fixture) = diff --git a/crates/openshell-supervisor-network/src/l7/rest.rs b/crates/openshell-supervisor-network/src/l7/rest.rs index 1a4036abd..2c85cacf6 100644 --- a/crates/openshell-supervisor-network/src/l7/rest.rs +++ b/crates/openshell-supervisor-network/src/l7/rest.rs @@ -774,11 +774,23 @@ pub(crate) struct BufferedRequestBody { pub(crate) body: Vec, } +/// Result of attempting to buffer a request body for middleware inspection. +pub(crate) enum BufferResult { + /// The full body was buffered within the size cap. + Buffered(BufferedRequestBody), + /// The body exceeded the inspection cap. `recoverable` is true when no body + /// bytes were consumed yet (a declared `Content-Length` over the cap), so the + /// request can still be streamed through unprocessed under fail-open. It is + /// false once bytes have been consumed (chunked overflow), where denying is + /// the only safe outcome. + OverCapacity { recoverable: bool }, +} + pub(crate) async fn buffer_request_body_for_middleware( req: &L7Request, client: &mut C, generation_guard: Option<&PolicyGenerationGuard>, -) -> Result { +) -> Result { let header_end = req .raw_header .windows(4) @@ -787,17 +799,19 @@ pub(crate) async fn buffer_request_body_for_middleware( let headers = req.raw_header[..header_end].to_vec(); let already_read = &req.raw_header[header_end..]; match req.body_length { - BodyLength::None => Ok(BufferedRequestBody { + BodyLength::None => Ok(BufferResult::Buffered(BufferedRequestBody { headers, body: already_read.to_vec(), - }), + })), BodyLength::ContentLength(len) => { - let len = usize::try_from(len) - .map_err(|_| miette!("request body is too large for middleware"))?; + // The declared length is known before any further reads, so an + // over-cap body here has not consumed the stream and can be passed + // through unprocessed if every middleware is fail-open. + let Ok(len) = usize::try_from(len) else { + return Ok(BufferResult::OverCapacity { recoverable: true }); + }; if len > MAX_MIDDLEWARE_BODY_BYTES { - return Err(miette!( - "middleware buffers at most {MAX_MIDDLEWARE_BODY_BYTES} request body bytes" - )); + return Ok(BufferResult::OverCapacity { recoverable: true }); } let initial_len = already_read.len().min(len); let mut body = Vec::with_capacity(len); @@ -818,11 +832,21 @@ pub(crate) async fn buffer_request_body_for_middleware( body.extend_from_slice(&buf[..n]); remaining -= n; } - Ok(BufferedRequestBody { headers, body }) + Ok(BufferResult::Buffered(BufferedRequestBody { + headers, + body, + })) } BodyLength::Chunked => { - let body = collect_chunked_body(client, already_read, generation_guard).await?; - Ok(BufferedRequestBody { headers, body }) + // Chunked bodies are decoded incrementally into the payload bytes + // middleware expects. On overflow, we have already consumed wire + // bytes from the client stream and cannot re-enter the normal raw + // relay path without a separate splice-through buffer. + Ok(collect_chunked_body(client, already_read, generation_guard) + .await + .map_or(BufferResult::OverCapacity { recoverable: false }, |body| { + BufferResult::Buffered(BufferedRequestBody { headers, body }) + })) } } } @@ -835,7 +859,7 @@ pub(crate) fn rebuild_request_with_buffered_body( ) -> Result { let mut header_bytes = set_content_length(headers, body.len())?; header_bytes = strip_header(&header_bytes, "transfer-encoding")?; - header_bytes = append_headers(&header_bytes, add_headers)?; + header_bytes = append_headers(&header_bytes, add_headers); header_bytes.extend_from_slice(body); Ok(L7Request { action: req.action.clone(), @@ -900,15 +924,11 @@ async fn collect_and_rewrite_request_body( } BodyLength::Chunked => { let body = collect_chunked_body(client, already_read, generation_guard).await?; - if body_bytes_contain_reserved_marker(&body) { - return Err(miette!( - "request body credential rewrite does not support chunked bodies containing credential placeholders" - )); - } - Ok(PreparedRequestBody { - headers: rewritten_headers.to_vec(), - body, - }) + let (mut headers, body) = + rewrite_buffered_body(rewritten_headers, original_header_str, body, resolver)?; + headers = set_content_length(&headers, body.len())?; + headers = strip_header(&headers, "transfer-encoding")?; + Ok(PreparedRequestBody { headers, body }) } } } @@ -1076,37 +1096,15 @@ async fn collect_chunked_body( already_read: &[u8], generation_guard: Option<&PolicyGenerationGuard>, ) -> Result> { - let mut read_buf = [0u8; RELAY_BUF_SIZE]; - let mut parse_buf = Vec::from(already_read); - let mut pos = 0usize; + let mut buffered_pos = 0usize; + let mut body = Vec::new(); loop { - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - )); - } - - let size_line_end = loop { - if let Some(end) = find_crlf(&parse_buf, pos) { - break end; - } - let n = client.read(&mut read_buf).await.into_diagnostic()?; - if n == 0 { - return Err(miette!("Chunked body ended before chunk-size line")); - } - if let Some(guard) = generation_guard { - guard.ensure_current()?; - } - parse_buf.extend_from_slice(&read_buf[..n]); - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - )); - } - }; - - let size_line = std::str::from_utf8(&parse_buf[pos..size_line_end]) + let size_line = + read_chunked_line(client, already_read, &mut buffered_pos, generation_guard) + .await + .map_err(|e| miette!("Chunked body ended before chunk-size line: {e}"))?; + let size_line = std::str::from_utf8(&size_line) .into_diagnostic() .map_err(|_| miette!("Invalid UTF-8 in chunk-size line"))?; let size_token = size_line @@ -1117,64 +1115,109 @@ async fn collect_chunked_body( let chunk_size = usize::from_str_radix(size_token, 16) .into_diagnostic() .map_err(|_| miette!("Invalid chunk size token: {size_token:?}"))?; - pos = size_line_end + 2; if chunk_size == 0 { loop { - let trailer_end = loop { - if let Some(end) = find_crlf(&parse_buf, pos) { - break end; - } - let n = client.read(&mut read_buf).await.into_diagnostic()?; - if n == 0 { - return Err(miette!("Chunked body ended before trailer terminator")); - } - if let Some(guard) = generation_guard { - guard.ensure_current()?; - } - parse_buf.extend_from_slice(&read_buf[..n]); - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - )); - } - }; - let trailer_line = &parse_buf[pos..trailer_end]; - pos = trailer_end + 2; + let trailer_line = + read_chunked_line(client, already_read, &mut buffered_pos, generation_guard) + .await + .map_err(|e| { + miette!("Chunked body ended before trailer terminator: {e}") + })?; if trailer_line.is_empty() { - return Ok(parse_buf); + return Ok(body); } } } - let chunk_end = pos - .checked_add(chunk_size) - .ok_or_else(|| miette!("Chunk size overflow"))?; - let chunk_with_crlf_end = chunk_end - .checked_add(2) - .ok_or_else(|| miette!("Chunk size overflow"))?; - while parse_buf.len() < chunk_with_crlf_end { - let n = client.read(&mut read_buf).await.into_diagnostic()?; - if n == 0 { - return Err(miette!("Chunked body ended mid-chunk")); - } - if let Some(guard) = generation_guard { - guard.ensure_current()?; - } - parse_buf.extend_from_slice(&read_buf[..n]); - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - )); - } + if body.len().saturating_add(chunk_size) > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); } - if &parse_buf[chunk_end..chunk_with_crlf_end] != b"\r\n" { + read_buffered_exact( + client, + already_read, + &mut buffered_pos, + chunk_size, + &mut body, + generation_guard, + ) + .await + .map_err(|e| miette!("Chunked body ended mid-chunk: {e}"))?; + + let mut chunk_crlf = Vec::with_capacity(2); + read_buffered_exact( + client, + already_read, + &mut buffered_pos, + 2, + &mut chunk_crlf, + generation_guard, + ) + .await + .map_err(|e| miette!("Chunked body ended before chunk terminator: {e}"))?; + if chunk_crlf.as_slice() != b"\r\n" { return Err(miette!("Chunk missing terminating CRLF")); } - pos = chunk_with_crlf_end; } } +async fn read_chunked_line( + client: &mut C, + already_read: &[u8], + buffered_pos: &mut usize, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result> { + let mut line = Vec::new(); + loop { + let byte = read_buffered_byte(client, already_read, buffered_pos, generation_guard).await?; + line.push(byte); + if line.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + if line.ends_with(b"\r\n") { + line.truncate(line.len() - 2); + return Ok(line); + } + } +} + +async fn read_buffered_exact( + client: &mut C, + already_read: &[u8], + buffered_pos: &mut usize, + len: usize, + out: &mut Vec, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result<()> { + for _ in 0..len { + let byte = read_buffered_byte(client, already_read, buffered_pos, generation_guard).await?; + out.push(byte); + } + Ok(()) +} + +async fn read_buffered_byte( + client: &mut C, + already_read: &[u8], + buffered_pos: &mut usize, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result { + if *buffered_pos < already_read.len() { + let byte = already_read[*buffered_pos]; + *buffered_pos += 1; + return Ok(byte); + } + let byte = client.read_u8().await.into_diagnostic()?; + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + Ok(byte) +} + fn content_type(headers: &str) -> Option { headers.lines().skip(1).find_map(|line| { let (name, value) = line.split_once(':')?; @@ -1262,9 +1305,9 @@ fn strip_header(headers: &[u8], strip_name: &str) -> Result> { fn append_headers( headers: &[u8], add_headers: &std::collections::BTreeMap, -) -> Result> { +) -> Vec { if add_headers.is_empty() { - return Ok(headers.to_vec()); + return headers.to_vec(); } let split = headers .windows(4) @@ -1279,7 +1322,7 @@ fn append_headers( out.extend_from_slice(value.as_bytes()); } out.extend_from_slice(b"\r\n\r\n"); - Ok(out) + out } pub(crate) fn request_is_websocket_upgrade(raw_header: &[u8]) -> bool { @@ -3151,6 +3194,20 @@ mod tests { } } + #[tokio::test] + async fn collect_chunked_body_decodes_payload_bytes() { + let mut client = tokio::io::empty(); + let body = collect_chunked_body( + &mut client, + b"5\r\nhello\r\n6;ext=value\r\n world\r\n0\r\nx-checksum: abc\r\n\r\n", + None, + ) + .await + .expect("chunked body should decode"); + + assert_eq!(body, b"hello world"); + } + /// SEC-009: Bare LF in headers enables header injection. #[tokio::test] async fn reject_bare_lf_in_headers() { @@ -5257,6 +5314,38 @@ mod tests { assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); } + #[tokio::test] + async fn relay_request_body_rewrite_normalizes_chunked_payload() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider.v1-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Transfer-Encoding: chunked\r\n\r\n\ + 5\r\nhello\r\n0\r\n\r\n", + ); + + let forwarded = relay_and_capture_with_options( + raw.into_bytes(), + BodyLength::Chunked, + Some(&resolver), + true, + ) + .await + .expect("relay should succeed"); + + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains("Content-Length: 5\r\n")); + assert!(!forwarded.contains("Transfer-Encoding: chunked\r\n")); + assert!(forwarded.ends_with("hello")); + } + #[tokio::test] async fn relay_request_body_rewrites_percent_encoded_canonical_urlencoded_token() { let (_, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index 451c57e59..a584b414b 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -777,11 +777,7 @@ fn query_middleware_chain_locked( .map_err(|e| miette::miette!("{e}"))?; let contexts = parse_middleware_contexts(&contexts_val); let Some(context) = select_middleware_context(&contexts, request_path) else { - return Ok(global_middleware_entries( - &configs, - &input.host, - &HashSet::new(), - )?); + return global_middleware_entries(&configs, &input.host, &HashSet::new()); }; let mut explicit = Vec::new(); @@ -876,12 +872,16 @@ fn middleware_selector_matches(config: ®orus::Value, host: &str) -> bool { let Some(selector) = get_field(config, "endpoints") else { return false; }; - let includes = get_str_array(selector, "include"); - let excludes = get_str_array(selector, "exclude"); - let included = - !includes.is_empty() && includes.iter().any(|pattern| host_matches(pattern, host)); - let excluded = excludes.iter().any(|pattern| host_matches(pattern, host)); - included && !excluded + let include_patterns = get_str_array(selector, "include"); + let exclude_patterns = get_str_array(selector, "exclude"); + let matches_include = !include_patterns.is_empty() + && include_patterns + .iter() + .any(|pattern| host_matches(pattern, host)); + let matches_exclude = exclude_patterns + .iter() + .any(|pattern| host_matches(pattern, host)); + matches_include && !matches_exclude } fn host_matches(pattern: &str, host: &str) -> bool { @@ -956,7 +956,6 @@ fn regorus_value_to_prost(value: ®orus::Value) -> prost_types::Value { }) .collect(), }), - regorus::Value::Null | regorus::Value::Undefined => Kind::NullValue(0), _ => Kind::NullValue(0), }), } From 27650555f7b122bf381317638cf5e79184901e95 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 26 Jun 2026 16:20:15 -0700 Subject: [PATCH 03/27] fix(supervisor-middleware): default stored policy rule fields Signed-off-by: Piotr Mlocek --- crates/openshell-server/src/grpc/policy.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 09e311bb2..ad4fdf5ba 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -7100,6 +7100,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-provider-prefix".to_string(), From e487f2c5de06e90db76c1d12449adda53abb982b Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 26 Jun 2026 16:58:26 -0700 Subject: [PATCH 04/27] fix(supervisor-middleware): resolve rebase policy conflicts Signed-off-by: Piotr Mlocek --- crates/openshell-policy/src/lib.rs | 4 +- .../data/sandbox-policy.rego | 35 ++-- .../openshell-supervisor-network/src/opa.rs | 183 +++++++++++++++--- proto/sandbox.proto | 2 + 4 files changed, 177 insertions(+), 47 deletions(-) diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 0aa43c30d..46054a91d 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -19,8 +19,8 @@ use std::path::Path; use miette::{IntoDiagnostic, Result, WrapErr}; use openshell_core::proto::{ FilesystemPolicy, GraphqlOperation, L7Allow, L7DenyRule, L7QueryMatcher, L7Rule, - LandlockPolicy, MiddlewareEndpointSelector, NetworkBinary, NetworkEndpoint, - NetworkMiddlewareConfig, NetworkPolicyRule, ProcessPolicy, SandboxPolicy, McpOptions, + LandlockPolicy, McpOptions, MiddlewareEndpointSelector, NetworkBinary, NetworkEndpoint, + NetworkMiddlewareConfig, NetworkPolicyRule, ProcessPolicy, SandboxPolicy, }; use serde::{Deserialize, Serialize}; diff --git a/crates/openshell-supervisor-network/data/sandbox-policy.rego b/crates/openshell-supervisor-network/data/sandbox-policy.rego index afa4f6947..9228416e1 100644 --- a/crates/openshell-supervisor-network/data/sandbox-policy.rego +++ b/crates/openshell-supervisor-network/data/sandbox-policy.rego @@ -842,14 +842,29 @@ _policy_endpoint_configs(policy) := [ep | endpoint_has_extended_config(ep) ] -# Collect matching endpoint configs across all policies. Iterates over -# _matching_policy_names (a set, safe from regorus variable collisions) -# then collects per-policy configs via the helper function. +# Collect matching endpoint identities across all policies. Iterates over +# _matching_policy_names (a set, safe from regorus variable collisions) then +# returns the selected policy name plus endpoint index/path. Rust uses that +# identity to look up middleware attachment from policy data. +_matching_endpoint_contexts := [ctx | + some pname + _matching_policy_names[pname] + policy := data.network_policies[pname] + ep := policy.endpoints[i] + endpoint_matches_request(ep, input.network) + ctx := { + "policy": pname, + "endpoint_index": i, + "endpoint_path": object.get(ep, "path", ""), + } +] + _matching_endpoint_configs := [cfg | some pname _matching_policy_names[pname] cfgs := _policy_endpoint_configs(data.network_policies[pname]) cfg := cfgs[_] + endpoint_has_extended_config(cfg) ] matched_endpoint_config := _matching_endpoint_configs[0] if { @@ -858,20 +873,6 @@ matched_endpoint_config := _matching_endpoint_configs[0] if { network_middlewares := object.get(data, "network_middlewares", []) -_matching_middleware_contexts := [ctx | - some pname - _matching_policy_names[pname] - policy := data.network_policies[pname] - some ep - ep := policy.endpoints[_] - endpoint_matches_request(ep, input.network) - ctx := { - "policy": pname, - "policy_middleware": object.get(policy, "middleware", []), - "endpoint": ep, - } -] - _policy_has_exact_declared_endpoint(policy) if { some ep ep := policy.endpoints[_] diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index a584b414b..3d0f75bf7 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -750,9 +750,9 @@ fn network_input_json(input: &NetworkInput) -> serde_json::Value { } #[derive(Debug, Clone)] -struct MiddlewareContext { - policy_middleware: Vec, - endpoint_middleware: Vec, +struct MatchedEndpointContext { + policy_name: String, + endpoint_index: usize, endpoint_path: String, } @@ -773,19 +773,20 @@ fn query_middleware_chain_locked( return Ok(Vec::new()); } let contexts_val = engine - .eval_rule("data.openshell.sandbox._matching_middleware_contexts".into()) + .eval_rule("data.openshell.sandbox._matching_endpoint_contexts".into()) .map_err(|e| miette::miette!("{e}"))?; - let contexts = parse_middleware_contexts(&contexts_val); - let Some(context) = select_middleware_context(&contexts, request_path) else { + let contexts = parse_endpoint_contexts(&contexts_val); + let Some(context) = select_endpoint_context(&contexts, request_path)? else { return global_middleware_entries(&configs, &input.host, &HashSet::new()); }; + let policies_val = engine + .eval_rule("data.network_policies".into()) + .map_err(|e| miette::miette!("{e}"))?; + let (policy_middleware, endpoint_middleware) = + middleware_for_endpoint_identity(&policies_val, context)?; let mut explicit = Vec::new(); - for name in context - .policy_middleware - .iter() - .chain(context.endpoint_middleware.iter()) - { + for name in policy_middleware.iter().chain(endpoint_middleware.iter()) { if !explicit.contains(name) { explicit.push(name.clone()); } @@ -814,7 +815,7 @@ fn parse_middleware_configs(value: ®orus::Value) -> Result Vec { +fn parse_endpoint_contexts(value: ®orus::Value) -> Vec { let regorus::Value::Array(values) = value else { return Vec::new(); }; @@ -824,30 +825,87 @@ fn parse_middleware_contexts(value: ®orus::Value) -> Vec { let regorus::Value::Object(_) = value else { return None; }; - let endpoint = get_field(value, "endpoint")?; - Some(MiddlewareContext { - policy_middleware: get_str_array(value, "policy_middleware"), - endpoint_middleware: get_str_array(endpoint, "middleware"), - endpoint_path: get_str(endpoint, "path").unwrap_or_default(), + Some(MatchedEndpointContext { + policy_name: get_str(value, "policy").unwrap_or_default(), + endpoint_index: get_usize(value, "endpoint_index").unwrap_or_default(), + endpoint_path: get_str(value, "endpoint_path").unwrap_or_default(), }) }) .collect() } -fn select_middleware_context<'a>( - contexts: &'a [MiddlewareContext], +fn middleware_for_endpoint_identity( + policies: ®orus::Value, + context: &MatchedEndpointContext, +) -> Result<(Vec, Vec)> { + let policy = get_field(policies, &context.policy_name).ok_or_else(|| { + miette::miette!( + "matched endpoint policy '{}' was not found in OPA data", + context.policy_name + ) + })?; + let endpoint = get_array(policy, "endpoints") + .and_then(|endpoints| endpoints.get(context.endpoint_index)) + .ok_or_else(|| { + miette::miette!( + "matched endpoint {}[{}] was not found in OPA data", + context.policy_name, + context.endpoint_index + ) + })?; + Ok(( + get_str_array(policy, "middleware"), + get_str_array(endpoint, "middleware"), + )) +} + +fn select_endpoint_context<'a>( + contexts: &'a [MatchedEndpointContext], request_path: &str, -) -> Option<&'a MiddlewareContext> { - contexts +) -> Result> { + let matching: Vec<_> = contexts .iter() .filter(|context| crate::l7::endpoint_path_matches(&context.endpoint_path, request_path)) - .max_by_key(|context| { - if context.endpoint_path.is_empty() { - 0 - } else { - context.endpoint_path.chars().filter(|c| *c != '*').count() - } - }) + .map(|context| (endpoint_path_specificity(&context.endpoint_path), context)) + .collect(); + let Some(max_specificity) = matching.iter().map(|(specificity, _)| *specificity).max() else { + return Ok(None); + }; + let best: Vec<_> = matching + .into_iter() + .filter(|(specificity, _)| *specificity == max_specificity) + .map(|(_, context)| context) + .collect(); + if best.len() > 1 { + let matches = best + .iter() + .map(|context| { + format!( + "{}[{}] path={}", + context.policy_name, + context.endpoint_index, + if context.endpoint_path.is_empty() { + "" + } else { + context.endpoint_path.as_str() + } + ) + }) + .collect::>() + .join(", "); + return Err(miette::miette!( + "ambiguous middleware endpoint match for request path '{request_path}': {matches}" + )); + } + Ok(best.into_iter().next()) +} + +fn endpoint_path_specificity(path: &str) -> usize { + if path.is_empty() { + 0 + } else { + path.chars().filter(|c| *c != '*').count() + } } fn global_middleware_entries( @@ -918,6 +976,25 @@ fn get_field<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Valu } } +fn get_array<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a [regorus::Value]> { + let regorus::Value::Array(values) = get_field(val, key)? else { + return None; + }; + Some(values) +} + +fn get_usize(val: ®orus::Value, key: &str) -> Option { + let value = get_field(val, key)?; + let regorus::Value::Number(number) = value else { + return None; + }; + let value = number.as_f64()?; + if !value.is_finite() || value.fract() != 0.0 || value < 0.0 { + return None; + } + format!("{value:.0}").parse::().ok() +} + fn regorus_value_to_struct(value: ®orus::Value) -> prost_types::Struct { let regorus::Value::Object(map) = value else { return prost_types::Struct::default(); @@ -3305,6 +3382,7 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], + middleware: vec![], }, ); @@ -3323,6 +3401,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3377,6 +3456,7 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], + middleware: vec![], }, ); @@ -3395,6 +3475,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -6955,6 +7036,52 @@ network_policies: ); } + #[test] + fn middleware_chain_rejects_ambiguous_duplicate_endpoint_identity() { + let data = r#" +network_middlewares: + - name: first-redactor + middleware: openshell/secrets + - name: second-redactor + middleware: openshell/secrets +network_policies: + api: + name: api + endpoints: + - host: api.example.com + port: 443 + protocol: rest + enforcement: enforce + middleware: ["first-redactor"] + access: full + - host: api.example.com + port: 443 + protocol: rest + enforcement: enforce + middleware: ["second-redactor"] + access: full + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let err = engine + .query_middleware_chain_with_generation(&input, "/v1/messages") + .expect_err("equivalent endpoint identities should be ambiguous"); + assert!( + err.to_string() + .contains("ambiguous middleware endpoint match"), + "{err:?}" + ); + } + #[test] fn middleware_policy_validation_rejects_bad_configs() { let cases = [ diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 5d2bc31a5..a73d762e5 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -202,6 +202,8 @@ message McpOptions { // MCP-family methods at the method layer unless a tool-name policy narrows // tools/call. When unset or false, explicit method rules are required. optional bool allow_all_known_mcp_methods = 2; +} + // Trusted GraphQL operation classification. message GraphqlOperation { // Operation type: "query", "mutation", or "subscription". From 9f892ef6cf32eec74c52019e7cf1b3fd03c37b4f Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Mon, 29 Jun 2026 10:15:03 -0700 Subject: [PATCH 05/27] feat(supervisor-middleware): implement phase one runtime Signed-off-by: Piotr Mlocek --- Cargo.lock | 1 + crates/openshell-policy/Cargo.toml | 1 + crates/openshell-policy/src/lib.rs | 47 +++++ .../src/lib.rs | 88 +++++++-- .../src/service.rs | 9 +- .../data/sandbox-policy.rego | 2 + .../src/l7/relay.rs | 164 ++++++++++++++++- .../src/l7/rest.rs | 63 +++++++ .../openshell-supervisor-network/src/opa.rs | 174 ++++++++++++++---- .../openshell-supervisor-network/src/proxy.rs | 67 +++++++ 10 files changed, 551 insertions(+), 65 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3ed95de90..c083f7a8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3771,6 +3771,7 @@ version = "0.0.0" dependencies = [ "miette", "openshell-core", + "openshell-supervisor-middleware", "prost-types", "serde", "serde_json", diff --git a/crates/openshell-policy/Cargo.toml b/crates/openshell-policy/Cargo.toml index 50bea5b32..7ccd5d967 100644 --- a/crates/openshell-policy/Cargo.toml +++ b/crates/openshell-policy/Cargo.toml @@ -12,6 +12,7 @@ repository.workspace = true [dependencies] openshell-core = { path = "../openshell-core", default-features = false } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } prost-types = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 46054a91d..46646f755 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -1253,6 +1253,8 @@ pub enum PolicyViolation { }, /// `credential_signing` and `request_body_credential_rewrite` are both set. CredentialSigningWithBodyRewrite { policy_name: String, host: String }, + /// A built-in middleware configuration is invalid. + InvalidBuiltinMiddlewareConfig { name: String, reason: String }, } impl fmt::Display for PolicyViolation { @@ -1317,6 +1319,9 @@ impl fmt::Display for PolicyViolation { and request_body_credential_rewrite set; these options are mutually exclusive" ) } + Self::InvalidBuiltinMiddlewareConfig { name, reason } => { + write!(f, "middleware config '{name}' is invalid: {reason}") + } } } } @@ -1449,6 +1454,21 @@ pub fn validate_sandbox_policy( } } + for middleware in &policy.network_middlewares { + if middleware.middleware.starts_with("openshell/") { + let config = middleware.config.as_ref().cloned().unwrap_or_default(); + if let Err(error) = openshell_supervisor_middleware::validate_builtin_config( + &middleware.middleware, + &config, + ) { + violations.push(PolicyViolation::InvalidBuiltinMiddlewareConfig { + name: middleware.name.clone(), + reason: error.to_string(), + }); + } + } + } + if violations.is_empty() { Ok(()) } else { @@ -1884,6 +1904,33 @@ network_policies: assert_eq!(violations.len(), 2); } + #[test] + fn validate_rejects_invalid_builtin_middleware_config() { + let mut policy = restrictive_default_policy(); + policy.network_middlewares.push(NetworkMiddlewareConfig { + name: "redact-secrets".into(), + middleware: "openshell/secrets".into(), + config: Some(prost_types::Struct { + fields: std::iter::once(( + "secrets".into(), + prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue("allow".into())), + }, + )) + .collect(), + }), + on_error: String::new(), + endpoints: None, + }); + + let violations = validate_sandbox_policy(&policy).expect_err("invalid config"); + assert!(violations.iter().any(|violation| matches!( + violation, + PolicyViolation::InvalidBuiltinMiddlewareConfig { name, .. } + if name == "redact-secrets" + ))); + } + #[test] fn validate_rejects_non_sandbox_user() { let mut policy = restrictive_default_policy(); diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index b68d83c86..4ec7e2782 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -14,7 +14,7 @@ pub use service::InProcessMiddlewareService; use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; use openshell_core::proto::{ - Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, NetworkMiddlewareConfig, Process, + Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, NetworkMiddlewareConfig, RequestContext, }; use tonic::Request; @@ -24,6 +24,19 @@ pub const HTTP_REQUEST_OPERATION: &str = "HttpRequest"; pub const PRE_CREDENTIALS_PHASE: &str = "pre_credentials"; pub const BUILTIN_SECRETS: &str = "openshell/secrets"; +/// Validate the configuration for an in-process middleware implementation. +/// +/// Policy admission uses this same implementation-specific validation before a +/// configuration can reach the request path. +pub fn validate_builtin_config(implementation: &str, config: &prost_types::Struct) -> Result<()> { + match implementation { + BUILTIN_SECRETS => builtins::secrets::validate_config(config), + other => Err(miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum OnError { FailClosed, @@ -76,9 +89,6 @@ impl TryFrom<&NetworkMiddlewareConfig> for ChainEntry { pub struct HttpRequestInput { pub request_id: String, pub sandbox_id: String, - pub binary: String, - pub pid: u32, - pub ancestors: Vec, pub scheme: String, pub host: String, pub port: u16, @@ -210,6 +220,26 @@ impl ChainRunner { } }; + let decision = match Decision::try_from(result.decision) { + Ok(decision @ (Decision::Allow | Decision::Deny)) => decision, + Ok(Decision::Unspecified) | Err(_) => { + match apply_on_error(entry, "invalid_response_decision", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + } + }; + // A result proposing unsafe header mutations is a malformed response: // route it through `on_error` instead of applying any of it. if validate_header_mutations(&headers, &result.add_headers).is_err() { @@ -251,11 +281,11 @@ impl ChainRunner { applied.push(MiddlewareInvocation { name: entry.name.clone(), implementation: entry.implementation.clone(), - decision: Decision::try_from(result.decision).unwrap_or(Decision::Unspecified), + decision, transformed, failed: false, }); - if result.decision == Decision::Deny as i32 { + if decision == Decision::Deny { return Ok(ChainOutcome { allowed: false, reason: safe_reason(&result.reason), @@ -293,11 +323,7 @@ fn build_evaluation( context: Some(RequestContext { request_id: input.request_id.clone(), sandbox_id: input.sandbox_id.clone(), - originating_process: Some(Process { - binary: input.binary.clone(), - pid: input.pid, - ancestors: input.ancestors.clone(), - }), + originating_process: None, }), config: Some(entry.config.clone()), target: Some(HttpRequestTarget { @@ -399,9 +425,6 @@ mod tests { HttpRequestInput { request_id: "req".into(), sandbox_id: "sbx".into(), - binary: "/usr/bin/curl".into(), - pid: 42, - ancestors: vec![], scheme: "https".into(), host: "api.example.com".into(), port: 443, @@ -413,6 +436,21 @@ mod tests { } } + #[test] + fn phase_one_evaluation_omits_originating_process() { + let entry = entry("redact", OnError::FailClosed); + let input = input("payload"); + let evaluation = build_evaluation(&entry, &input, &BTreeMap::new(), b"payload"); + + assert!( + evaluation + .context + .expect("request context") + .originating_process + .is_none() + ); + } + #[tokio::test] async fn redacts_common_secret_patterns() { let outcome = ChainRunner::default() @@ -684,4 +722,26 @@ mod tests { assert_eq!(outcome.applied.len(), 1); assert!(outcome.applied[0].failed); } + + #[tokio::test] + async fn unspecified_decision_uses_fail_closed() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + result: openshell_core::proto::HttpRequestResult { + decision: Decision::Unspecified as i32, + ..allow_result() + }, + })); + + let outcome = runner + .evaluate(&[entry("redact", OnError::FailClosed)], input("hello")) + .await + .expect("evaluate"); + + assert!(!outcome.allowed); + assert_eq!( + outcome.reason, + "middleware_failed: invalid_response_decision" + ); + assert!(outcome.applied[0].failed); + } } diff --git a/crates/openshell-supervisor-middleware/src/service.rs b/crates/openshell-supervisor-middleware/src/service.rs index 31cca5694..cbd9231cd 100644 --- a/crates/openshell-supervisor-middleware/src/service.rs +++ b/crates/openshell-supervisor-middleware/src/service.rs @@ -10,7 +10,7 @@ use tonic::{Request, Response, Status}; use crate::{ API_VERSION, BUILTIN_SECRETS, HTTP_REQUEST_OPERATION, PRE_CREDENTIALS_PHASE, builtins, - safe_reason, + safe_reason, validate_builtin_config, }; #[derive(Debug, Default)] @@ -40,12 +40,7 @@ impl SupervisorMiddleware for InProcessMiddlewareService { ) -> Result, Status> { let request = request.into_inner(); let config = request.config.unwrap_or_default(); - let validation = match request.binding_id.as_str() { - BUILTIN_SECRETS => builtins::secrets::validate_config(&config), - other => Err(miette::miette!( - "middleware implementation '{other}' is not available in phase 1" - )), - }; + let validation = validate_builtin_config(&request.binding_id, &config); Ok(Response::new(match validation { Ok(()) => ValidateConfigResponse { valid: true, diff --git a/crates/openshell-supervisor-network/data/sandbox-policy.rego b/crates/openshell-supervisor-network/data/sandbox-policy.rego index 9228416e1..52f6f1046 100644 --- a/crates/openshell-supervisor-network/data/sandbox-policy.rego +++ b/crates/openshell-supervisor-network/data/sandbox-policy.rego @@ -871,6 +871,8 @@ matched_endpoint_config := _matching_endpoint_configs[0] if { count(_matching_endpoint_configs) > 0 } +network_policies := object.get(data, "network_policies", {}) + network_middlewares := object.get(data, "network_middlewares", []) _policy_has_exact_declared_endpoint(policy) if { diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 6e5c3c4e9..c773fdcf4 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -768,12 +768,12 @@ fn jsonrpc_engine_type(protocol: L7Protocol) -> &'static str { } } -enum MiddlewareApplyResult { +pub(crate) enum MiddlewareApplyResult { Allowed(crate::l7::provider::L7Request), Denied(String), } -async fn apply_middleware_chain( +pub(crate) async fn apply_middleware_chain( req: crate::l7::provider::L7Request, client: &mut C, ctx: &L7EvalContext, @@ -796,18 +796,16 @@ async fn apply_middleware_chain( } }; let headers = safe_middleware_headers(&buffered.headers)?; + let query = raw_query_from_request_headers(&buffered.headers)?; let input = openshell_supervisor_middleware::HttpRequestInput { request_id: uuid::Uuid::new_v4().to_string(), - sandbox_id: String::new(), - binary: ctx.binary_path.clone(), - pid: 0, - ancestors: ctx.ancestors.clone(), + sandbox_id: openshell_ocsf::ctx::ctx().sandbox_id.clone(), scheme: "https".into(), host: ctx.host.clone(), port: ctx.port, method: req.action.clone(), path: req.target.clone(), - query: String::new(), + query, headers, body: buffered.body, }; @@ -828,6 +826,19 @@ async fn apply_middleware_chain( } } +fn raw_query_from_request_headers(headers: &[u8]) -> Result { + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let target = header_str + .lines() + .next() + .and_then(|line| line.split_whitespace().nth(1)) + .ok_or_else(|| miette!("HTTP request line is missing a target"))?; + Ok(target + .split_once('?') + .map_or_else(String::new, |(_, query)| query.to_string())) +} + /// Apply the chain's `on_error` policy when the request body cannot be buffered /// for inspection because it exceeds the size cap. The RFC treats an unbufferable /// body as an `on_error` event: it is denied unless every attached middleware is @@ -1446,6 +1457,37 @@ where } if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let chain = + engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let req = + match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; // Future MCP response/SSE introspection or rewrite would hook here // before returning upstream bytes. The current policy schema has no // trusted-annotations or version-profile field, so MCP responses and @@ -2736,6 +2778,104 @@ network_policies: .unwrap(); } + #[tokio::test] + async fn jsonrpc_middleware_fail_closed_does_not_reach_upstream() { + let data = r#" +network_middlewares: + - name: request-middleware + middleware: example/unavailable + on_error: fail_closed +network_policies: + jsonrpc_api: + name: jsonrpc_api + middleware: ["request-middleware"] + endpoints: + - host: api.example.test + port: 443 + protocol: json-rpc + enforcement: enforce + rules: + - allow: + method: reports.list + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .expect("endpoint config"); + let config = crate::l7::parse_l7_config(&endpoint_config.expect("json-rpc config")) + .expect("parse JSON-RPC config"); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 443, + policy_name: "jsonrpc_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_jsonrpc( + &config, + &tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + let body = br#"{"jsonrpc":"2.0","id":1,"method":"reports.list"}"#; + let request = format!( + "POST /rpc HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + std::str::from_utf8(body).unwrap() + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("denial should reach client") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("middleware_failed")); + + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_request), + ) + .await; + assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive request bytes" + ); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + #[tokio::test] async fn l7_rest_middleware_over_capacity_fails_closed() { let (config, tunnel_engine, ctx) = @@ -2842,6 +2982,16 @@ network_policies: )); } + #[test] + fn middleware_keeps_the_raw_request_query() { + let query = raw_query_from_request_headers( + b"POST /v1/messages?token=a%2Bb&scope=private HTTP/1.1\r\nHost: api.example.test\r\n\r\n", + ) + .expect("query from request headers"); + + assert_eq!(query, "token=a%2Bb&scope=private"); + } + /// Tracing layer that captures emitted `OcsfEvent`s for assertions. struct OcsfCaptureLayer(Arc>>); diff --git a/crates/openshell-supervisor-network/src/l7/rest.rs b/crates/openshell-supervisor-network/src/l7/rest.rs index 2c85cacf6..19f73e2ad 100644 --- a/crates/openshell-supervisor-network/src/l7/rest.rs +++ b/crates/openshell-supervisor-network/src/l7/rest.rs @@ -246,6 +246,36 @@ async fn parse_http_request( })) } +/// Build an L7 request from a request already buffered by another proxy path. +/// +/// The forward proxy needs this after it has consumed the incoming HTTP/1 +/// headers itself. Keep the framing and query parsing here so it matches the +/// stream-based REST parser rather than growing another local parser. +pub(crate) fn request_from_buffered_http( + action: impl Into, + target: impl Into, + query_target: &str, + raw_header: Vec, +) -> Result { + let header_end = raw_header + .windows(4) + .position(|window| window == b"\r\n\r\n") + .ok_or_else(|| miette!("HTTP request headers are missing the CRLF terminator"))? + + 4; + let header_str = std::str::from_utf8(&raw_header[..header_end]) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let body_length = parse_body_length(header_str)?; + let (_, query_params) = parse_target_query(query_target)?; + + Ok(L7Request { + action: action.into(), + target: target.into(), + query_params, + raw_header, + body_length, + }) +} + /// Rebuild the request line in a raw HTTP header block with a canonicalized /// target. Called when the canonical path differs from what the client sent, /// so the upstream dispatches on the exact bytes the policy engine evaluated. @@ -3015,6 +3045,39 @@ mod tests { } } + #[test] + fn buffered_request_parser_uses_shared_framing_and_query_parsing() { + let request = request_from_buffered_http( + "POST", + "/v1/items", + "/v1/items?tag=first&tag=second", + b"POST /v1/items?tag=first&tag=second HTTP/1.1\r\nHost: api.example.com\r\nContent-Length: 3\r\n\r\nabc" + .to_vec(), + ) + .expect("parse buffered request"); + + assert_eq!(request.action, "POST"); + assert_eq!(request.target, "/v1/items"); + assert_eq!( + request.query_params.get("tag"), + Some(&vec!["first".to_string(), "second".to_string()]) + ); + assert!(matches!(request.body_length, BodyLength::ContentLength(3))); + } + + #[test] + fn buffered_request_parser_rejects_missing_header_terminator() { + let err = request_from_buffered_http( + "GET", + "/v1/items", + "/v1/items", + b"GET /v1/items HTTP/1.1\r\nHost: api.example.com\r\n".to_vec(), + ) + .expect_err("unterminated headers must be rejected"); + + assert!(err.to_string().contains("missing the CRLF terminator")); + } + #[test] fn parse_chunked() { let headers = diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index 3d0f75bf7..3efec0212 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -776,12 +776,12 @@ fn query_middleware_chain_locked( .eval_rule("data.openshell.sandbox._matching_endpoint_contexts".into()) .map_err(|e| miette::miette!("{e}"))?; let contexts = parse_endpoint_contexts(&contexts_val); - let Some(context) = select_endpoint_context(&contexts, request_path)? else { - return global_middleware_entries(&configs, &input.host, &HashSet::new()); - }; let policies_val = engine - .eval_rule("data.network_policies".into()) + .eval_rule("data.openshell.sandbox.network_policies".into()) .map_err(|e| miette::miette!("{e}"))?; + let Some(context) = select_endpoint_context(&contexts, request_path, &policies_val)? else { + return global_middleware_entries(&configs, &input.host, &HashSet::new()); + }; let (policy_middleware, endpoint_middleware) = middleware_for_endpoint_identity(&policies_val, context)?; @@ -862,6 +862,7 @@ fn middleware_for_endpoint_identity( fn select_endpoint_context<'a>( contexts: &'a [MatchedEndpointContext], request_path: &str, + policies: ®orus::Value, ) -> Result> { let matching: Vec<_> = contexts .iter() @@ -876,30 +877,56 @@ fn select_endpoint_context<'a>( .filter(|(specificity, _)| *specificity == max_specificity) .map(|(_, context)| context) .collect(); - if best.len() > 1 { - let matches = best - .iter() - .map(|context| { - format!( - "{}[{}] path={}", - context.policy_name, - context.endpoint_index, - if context.endpoint_path.is_empty() { - "" - } else { - context.endpoint_path.as_str() - } - ) - }) - .collect::>() - .join(", "); - return Err(miette::miette!( - "ambiguous middleware endpoint match for request path '{request_path}': {matches}" - )); + if let Some((first, rest)) = best.split_first() { + let first_middleware = explicit_middleware_for_endpoint_identity(policies, first)?; + for context in rest { + if explicit_middleware_for_endpoint_identity(policies, context)? != first_middleware { + let matches = best + .iter() + .map(|context| { + format!( + "{}[{}] path={}", + context.policy_name, + context.endpoint_index, + if context.endpoint_path.is_empty() { + "" + } else { + context.endpoint_path.as_str() + } + ) + }) + .collect::>() + .join(", "); + return Err(miette::miette!( + "ambiguous middleware endpoint match for request path '{request_path}': {matches}" + )); + } + } } Ok(best.into_iter().next()) } +fn explicit_middleware_for_endpoint_identity( + policies: ®orus::Value, + context: &MatchedEndpointContext, +) -> Result> { + let (policy_middleware, endpoint_middleware) = + middleware_for_endpoint_identity(policies, context)?; + Ok(dedup_middleware_names( + policy_middleware.iter().chain(endpoint_middleware.iter()), + )) +} + +fn dedup_middleware_names<'a>(names: impl IntoIterator) -> Vec { + let mut deduped = Vec::new(); + for name in names { + if !deduped.contains(name) { + deduped.push(name.clone()); + } + } + deduped +} + fn endpoint_path_specificity(path: &str) -> usize { if path.is_empty() { 0 @@ -1402,10 +1429,91 @@ fn validate_middleware_policies(data: &serde_json::Value) -> Vec { )); } } + validate_ambiguous_middleware_endpoints( + policy_name, + policy, + &policy_middleware, + &mut errors, + ); } errors } +fn validate_ambiguous_middleware_endpoints( + policy_name: &str, + policy: &serde_json::Value, + policy_middleware: &[String], + errors: &mut Vec, +) { + let endpoints = policy + .get("endpoints") + .and_then(serde_json::Value::as_array) + .map_or(&[][..], Vec::as_slice); + let mut seen: Vec<(usize, MiddlewareEndpointKey, Vec)> = Vec::new(); + for (index, endpoint) in endpoints.iter().enumerate() { + let key = middleware_endpoint_key(endpoint); + let endpoint_middleware = json_string_array(endpoint.get("middleware")); + let chain = + dedup_middleware_names(policy_middleware.iter().chain(endpoint_middleware.iter())); + for (previous_index, previous_key, previous_chain) in &seen { + if previous_key == &key && previous_chain != &chain { + errors.push(format!( + "network policy '{policy_name}' endpoints[{previous_index}] and endpoints[{index}] have equivalent middleware selection keys ({key}) but different middleware chains" + )); + } + } + seen.push((index, key, chain)); + } +} + +#[derive(Debug, PartialEq, Eq)] +struct MiddlewareEndpointKey { + host: String, + ports: Vec, + path: String, +} + +impl std::fmt::Display for MiddlewareEndpointKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "host={} ports={:?} path={}", + if self.host.is_empty() { + "" + } else { + self.host.as_str() + }, + self.ports, + if self.path.is_empty() { + "" + } else { + self.path.as_str() + } + ) + } +} + +fn middleware_endpoint_key(endpoint: &serde_json::Value) -> MiddlewareEndpointKey { + let host = endpoint + .get("host") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_ascii_lowercase(); + let mut ports: Vec = endpoint + .get("ports") + .and_then(serde_json::Value::as_array) + .map(|ports| ports.iter().filter_map(serde_json::Value::as_u64).collect()) + .unwrap_or_default(); + ports.sort_unstable(); + ports.dedup(); + let path = endpoint + .get("path") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_string(); + MiddlewareEndpointKey { host, ports, path } +} + fn json_string_array(value: Option<&serde_json::Value>) -> Vec { value .and_then(serde_json::Value::as_array) @@ -7037,7 +7145,7 @@ network_policies: } #[test] - fn middleware_chain_rejects_ambiguous_duplicate_endpoint_identity() { + fn middleware_validation_rejects_ambiguous_duplicate_endpoint_middleware() { let data = r#" network_middlewares: - name: first-redactor @@ -7063,21 +7171,13 @@ network_policies: binaries: - { path: /usr/bin/curl } "#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], + let err = match OpaEngine::from_strings(TEST_POLICY, data) { + Ok(_) => panic!("equivalent endpoints with different middleware should be invalid"), + Err(err) => err, }; - let err = engine - .query_middleware_chain_with_generation(&input, "/v1/messages") - .expect_err("equivalent endpoint identities should be ambiguous"); assert!( err.to_string() - .contains("ambiguous middleware endpoint match"), + .contains("equivalent middleware selection keys"), "{err:?}" ); } diff --git a/crates/openshell-supervisor-network/src/proxy.rs b/crates/openshell-supervisor-network/src/proxy.rs index af3331735..f8310fbdc 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -4184,6 +4184,73 @@ async fn handle_forward_proxy( } emit_forward_success_activity(activity_tx, l7_activity_pending); + let middleware_path = path.split_once('?').map_or(path.as_str(), |(path, _)| path); + let middleware_input = crate::opa::NetworkInput { + host: host_lc.clone(), + port, + binary_path: decision.binary.clone().unwrap_or_default(), + binary_sha256: String::new(), + ancestors: decision.ancestors.clone(), + cmdline_paths: decision.cmdline_paths.clone(), + }; + let (chain, generation) = + opa_engine.query_middleware_chain_with_generation(&middleware_input, middleware_path)?; + if generation != forward_generation_guard.captured_generation() { + emit_l7_tunnel_close_after_policy_change( + &host_lc, + port, + miette::miette!( + "policy changed before forward middleware evaluation [expected_generation:{} current_generation:{}]", + forward_generation_guard.captured_generation(), + generation, + ), + ); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("{method} {host_lc}:{port}{path} not permitted by policy"), + ), + ) + .await?; + return Ok(()); + } + if !chain.is_empty() { + let request = crate::l7::rest::request_from_buffered_http( + method, + middleware_path, + &upstream_target, + forward_request_bytes, + )?; + forward_request_bytes = match crate::l7::relay::apply_middleware_chain( + request, + client, + &l7_ctx, + chain, + &forward_generation_guard, + ) + .await? + { + crate::l7::relay::MiddlewareApplyResult::Allowed(request) => request.raw_header, + crate::l7::relay::MiddlewareApplyResult::Denied(reason) => { + emit_activity_simple(activity_tx, true, "middleware"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "middleware_denied", + &format!("{method} {host_lc}:{port}{path} denied by middleware: {reason}"), + ), + ) + .await?; + return Ok(()); + } + }; + } + forward_request_bytes = match inject_token_grant_for_forward_request( method, &upstream_target, From da486b2f3e11621190dbf718fe66ba383b853279 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Mon, 29 Jun 2026 14:10:40 -0700 Subject: [PATCH 06/27] fix(supervisor-middleware): harden selection and buffering Signed-off-by: Piotr Mlocek --- Cargo.lock | 1 + crates/openshell-cli/src/policy_update.rs | 1 - crates/openshell-policy/Cargo.toml | 1 + crates/openshell-policy/src/compose.rs | 1 - crates/openshell-policy/src/lib.rs | 373 ++++++++++-- crates/openshell-policy/src/merge.rs | 24 - crates/openshell-providers/src/profiles.rs | 2 - .../src/mechanistic_mapper.rs | 1 - crates/openshell-server/src/grpc/policy.rs | 32 -- .../openshell-server/src/grpc/validation.rs | 22 + .../Cargo.toml | 4 +- .../src/builtins/mod.rs | 25 + .../src/builtins/secrets.rs | 22 +- .../src/lib.rs | 383 +++++++++++-- .../src/service.rs | 24 +- .../data/sandbox-policy.rego | 19 - .../src/l7/relay.rs | 151 +++-- .../src/l7/rest.rs | 138 +++-- .../openshell-supervisor-network/src/opa.rs | 534 ++++-------------- .../src/policy_local.rs | 5 - .../openshell-supervisor-network/src/proxy.rs | 5 +- proto/middleware.proto | 2 + proto/sandbox.proto | 8 +- 23 files changed, 1073 insertions(+), 705 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c083f7a8d..0634d12e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3769,6 +3769,7 @@ dependencies = [ name = "openshell-policy" version = "0.0.0" dependencies = [ + "glob", "miette", "openshell-core", "openshell-supervisor-middleware", diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 824b1dde0..1f1f64750 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -65,7 +65,6 @@ pub fn build_policy_update_plan( ..Default::default() }) .collect(), - middleware: Vec::new(), }; merge_operations.push(PolicyMergeOperation { operation: Some(policy_merge_operation::Operation::AddRule(AddNetworkRule { diff --git a/crates/openshell-policy/Cargo.toml b/crates/openshell-policy/Cargo.toml index 7ccd5d967..073728db1 100644 --- a/crates/openshell-policy/Cargo.toml +++ b/crates/openshell-policy/Cargo.toml @@ -11,6 +11,7 @@ license.workspace = true repository.workspace = true [dependencies] +glob = { workspace = true } openshell-core = { path = "../openshell-core", default-features = false } openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } prost-types = { workspace = true } diff --git a/crates/openshell-policy/src/compose.rs b/crates/openshell-policy/src/compose.rs index 1ad0d4617..7ca8584d9 100644 --- a/crates/openshell-policy/src/compose.rs +++ b/crates/openshell-policy/src/compose.rs @@ -115,7 +115,6 @@ mod tests { ..Default::default() }], binaries: Vec::new(), - middleware: Vec::new(), } } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 46646f755..29585a623 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -12,7 +12,7 @@ mod compose; mod merge; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt; use std::path::Path; @@ -89,8 +89,6 @@ struct NetworkPolicyRuleDef { endpoints: Vec, #[serde(default, skip_serializing_if = "Vec::is_empty")] binaries: Vec, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - middleware: Vec, } #[derive(Debug, Serialize, Deserialize)] @@ -174,8 +172,6 @@ struct NetworkEndpointDef { json_rpc: Option, #[serde(default, skip_serializing_if = "Option::is_none")] mcp: Option, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - middleware: Vec, } // Signature dictated by serde's `skip_serializing_if`, which requires `&T`. @@ -788,7 +784,6 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { signing_region: e.signing_region, json_rpc_max_body_bytes: json_rpc_max_body_bytes(&e.json_rpc, &e.mcp), mcp: mcp_options(&e.mcp), - middleware: e.middleware, } }) .collect(), @@ -800,7 +795,6 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { ..Default::default() }) .collect(), - middleware: rule.middleware, }; (key, proto_rule) }) @@ -938,7 +932,6 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { signing_region: e.signing_region.clone(), json_rpc, mcp, - middleware: e.middleware.clone(), } }) .collect(), @@ -950,7 +943,6 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { harness: false, }) .collect(), - middleware: rule.middleware.clone(), }; (key.clone(), yaml_rule) }) @@ -1255,6 +1247,16 @@ pub enum PolicyViolation { CredentialSigningWithBodyRewrite { policy_name: String, host: String }, /// A built-in middleware configuration is invalid. InvalidBuiltinMiddlewareConfig { name: String, reason: String }, + /// A middleware configuration is structurally invalid. + InvalidMiddlewareConfig { name: String, reason: String }, + /// Middleware configuration names must be unique. + DuplicateMiddlewareConfigName { name: String }, + /// A middleware selector conflicts with an endpoint that skips TLS inspection. + MiddlewareTlsSkipConflict { + middleware_name: String, + policy_name: String, + host: String, + }, } impl fmt::Display for PolicyViolation { @@ -1319,13 +1321,67 @@ impl fmt::Display for PolicyViolation { and request_body_credential_rewrite set; these options are mutually exclusive" ) } - Self::InvalidBuiltinMiddlewareConfig { name, reason } => { + Self::InvalidBuiltinMiddlewareConfig { name, reason } + | Self::InvalidMiddlewareConfig { name, reason } => { write!(f, "middleware config '{name}' is invalid: {reason}") } + Self::DuplicateMiddlewareConfigName { name } => { + write!(f, "duplicate middleware config '{name}'") + } + Self::MiddlewareTlsSkipConflict { + middleware_name, + policy_name, + host, + } => { + write!( + f, + "middleware config '{middleware_name}' selects network policy \ + '{policy_name}' tls: skip endpoint '{host}'" + ) + } } } } +/// Match a middleware host selector pattern using the runtime's glob semantics. +/// +/// Invalid or empty patterns return an error instead of silently becoming a +/// non-match. +pub fn middleware_host_matches(pattern: &str, host: &str) -> std::result::Result { + if pattern.is_empty() { + return Err("host pattern must not be empty".to_string()); + } + if pattern.chars().any(char::is_whitespace) { + return Err("host pattern must not contain whitespace".to_string()); + } + + let pattern = glob::Pattern::new(&pattern.to_ascii_lowercase()) + .map_err(|error| format!("invalid host pattern: {error}"))?; + Ok(pattern.matches(&host.to_ascii_lowercase())) +} + +fn middleware_selector_matches_host( + middleware: &NetworkMiddlewareConfig, + host: &str, +) -> std::result::Result { + let Some(selector) = &middleware.endpoints else { + return Ok(false); + }; + let matches_include = selector + .include + .iter() + .try_fold(false, |matched, pattern| { + middleware_host_matches(pattern, host).map(|matches| matched || matches) + })?; + let matches_exclude = selector + .exclude + .iter() + .try_fold(false, |matched, pattern| { + middleware_host_matches(pattern, host).map(|matches| matched || matches) + })?; + Ok(matches_include && !matches_exclude) +} + /// Validate that a sandbox policy does not contain unsafe content. /// /// Returns `Ok(())` if the policy is safe, or `Err(violations)` listing all @@ -1340,6 +1396,9 @@ impl fmt::Display for PolicyViolation { /// - Individual path lengths must not exceed [`MAX_PATH_LENGTH`] /// - Total path count must not exceed [`MAX_FILESYSTEM_PATHS`] /// - Network endpoint hosts must not use TLD wildcards (e.g. `*.com`) +/// - Middleware names, implementations, failure modes, selectors, and built-in +/// configurations must be valid +/// - Middleware selectors must not match endpoints that skip TLS inspection pub fn validate_sandbox_policy( policy: &SandboxPolicy, ) -> std::result::Result<(), Vec> { @@ -1454,9 +1513,67 @@ pub fn validate_sandbox_policy( } } + let mut middleware_names = HashSet::new(); for middleware in &policy.network_middlewares { - if middleware.middleware.starts_with("openshell/") { - let config = middleware.config.as_ref().cloned().unwrap_or_default(); + if middleware.name.is_empty() { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "name must not be empty".to_string(), + }); + } else if !middleware_names.insert(middleware.name.clone()) { + violations.push(PolicyViolation::DuplicateMiddlewareConfigName { + name: middleware.name.clone(), + }); + } + + if middleware.middleware.is_empty() { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "implementation must not be empty".to_string(), + }); + } else if middleware.middleware.starts_with("openshell/") + && middleware.middleware != openshell_supervisor_middleware::BUILTIN_SECRETS + { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("unsupported built-in '{}'", middleware.middleware), + }); + } + + if !matches!( + middleware.on_error.as_str(), + "" | "fail_closed" | "fail_open" + ) { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("invalid on_error '{}'", middleware.on_error), + }); + } + + let Some(selector) = &middleware.endpoints else { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "endpoint selector is required".to_string(), + }); + continue; + }; + if selector.include.is_empty() { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "endpoint selector must include at least one host pattern".to_string(), + }); + } + for pattern in selector.include.iter().chain(&selector.exclude) { + if let Err(reason) = middleware_host_matches(pattern, "validation.invalid") { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("endpoint selector pattern '{pattern}' is invalid: {reason}"), + }); + } + } + + if middleware.middleware == openshell_supervisor_middleware::BUILTIN_SECRETS { + let config = middleware.config.clone().unwrap_or_default(); if let Err(error) = openshell_supervisor_middleware::validate_builtin_config( &middleware.middleware, &config, @@ -1467,6 +1584,25 @@ pub fn validate_sandbox_policy( }); } } + + for (key, rule) in &policy.network_policies { + let policy_name = if rule.name.is_empty() { + key + } else { + &rule.name + }; + for endpoint in &rule.endpoints { + if endpoint.tls == "skip" + && middleware_selector_matches_host(middleware, &endpoint.host).unwrap_or(false) + { + violations.push(PolicyViolation::MiddlewareTlsSkipConflict { + middleware_name: middleware.name.clone(), + policy_name: policy_name.clone(), + host: endpoint.host.clone(), + }); + } + } + } } if violations.is_empty() { @@ -1608,17 +1744,17 @@ network_middlewares: service: mode: redact max_matches: 2 - - name: endpoint-redactor + - name: secondary-redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] network_policies: api: name: api - middleware: ["global-redactor"] endpoints: - host: api.example.com port: 443 protocol: rest - middleware: ["endpoint-redactor"] binaries: - path: /usr/bin/curl "#; @@ -1651,26 +1787,9 @@ network_policies: .fields .contains_key("service") ); - assert_eq!( - proto.network_policies["api"].middleware, - vec!["global-redactor"] - ); - assert_eq!( - proto.network_policies["api"].endpoints[0].middleware, - vec!["endpoint-redactor"] - ); - let yaml_out = serialize_sandbox_policy(&proto).expect("serialize failed"); let reparsed = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); assert_eq!(reparsed.network_middlewares, proto.network_middlewares); - assert_eq!( - reparsed.network_policies["api"].middleware, - vec!["global-redactor"] - ); - assert_eq!( - reparsed.network_policies["api"].endpoints[0].middleware, - vec!["endpoint-redactor"] - ); } #[test] @@ -1803,6 +1922,31 @@ network_policies: assert!(parse_sandbox_policy(yaml).is_err()); } + #[test] + fn parse_rejects_middleware_attachments_on_network_policies_and_endpoints() { + let policy_attachment = r" +version: 1 +network_policies: + api: + middleware: [redact] + endpoints: + - host: api.example.com + port: 443 +"; + assert!(parse_sandbox_policy(policy_attachment).is_err()); + + let endpoint_attachment = r" +version: 1 +network_policies: + api: + endpoints: + - host: api.example.com + port: 443 + middleware: [redact] +"; + assert!(parse_sandbox_policy(endpoint_attachment).is_err()); + } + #[test] fn l7_config_stanza_runtime_fields_use_canonical_schema() { let fields = l7_config_alias_runtime_fields( @@ -1876,6 +2020,19 @@ network_policies: // ---- Policy validation tests ---- + fn middleware_config(name: &str, implementation: &str) -> NetworkMiddlewareConfig { + NetworkMiddlewareConfig { + name: name.into(), + middleware: implementation.into(), + config: None, + on_error: String::new(), + endpoints: Some(MiddlewareEndpointSelector { + include: vec!["api.example.com".into()], + exclude: Vec::new(), + }), + } + } + #[test] fn validate_rejects_root_run_as_user() { let mut policy = restrictive_default_policy(); @@ -1907,21 +2064,17 @@ network_policies: #[test] fn validate_rejects_invalid_builtin_middleware_config() { let mut policy = restrictive_default_policy(); - policy.network_middlewares.push(NetworkMiddlewareConfig { - name: "redact-secrets".into(), - middleware: "openshell/secrets".into(), - config: Some(prost_types::Struct { - fields: std::iter::once(( - "secrets".into(), - prost_types::Value { - kind: Some(prost_types::value::Kind::StringValue("allow".into())), - }, - )) - .collect(), - }), - on_error: String::new(), - endpoints: None, + let mut middleware = middleware_config("redact-secrets", "openshell/secrets"); + middleware.config = Some(prost_types::Struct { + fields: std::iter::once(( + "secrets".into(), + prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue("allow".into())), + }, + )) + .collect(), }); + policy.network_middlewares.push(middleware); let violations = validate_sandbox_policy(&policy).expect_err("invalid config"); assert!(violations.iter().any(|violation| matches!( @@ -1931,6 +2084,134 @@ network_policies: ))); } + #[test] + fn validate_rejects_invalid_middleware_control_fields() { + let cases = [ + ( + middleware_config("", "openshell/secrets"), + "name must not be empty", + ), + ( + middleware_config("redactor", ""), + "implementation must not be empty", + ), + ( + middleware_config("redactor", "openshell/unknown"), + "unsupported built-in", + ), + ( + { + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.on_error = "maybe".into(); + middleware + }, + "invalid on_error", + ), + ( + { + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.endpoints = None; + middleware + }, + "endpoint selector is required", + ), + ( + { + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.endpoints.as_mut().unwrap().include.clear(); + middleware + }, + "must include at least one host pattern", + ), + ]; + + for (middleware, expected) in cases { + let mut policy = restrictive_default_policy(); + policy.network_middlewares.push(middleware); + let errors = validate_sandbox_policy(&policy) + .expect_err("invalid middleware must be rejected") + .into_iter() + .map(|violation| violation.to_string()) + .collect::>() + .join("; "); + assert!( + errors.contains(expected), + "expected {expected:?} in {errors:?}" + ); + } + } + + #[test] + fn validate_rejects_duplicate_middleware_config_names() { + let mut policy = restrictive_default_policy(); + policy + .network_middlewares + .push(middleware_config("redactor", "openshell/secrets")); + policy + .network_middlewares + .push(middleware_config("redactor", "openshell/secrets")); + + let violations = validate_sandbox_policy(&policy).expect_err("duplicate name"); + assert!(violations.iter().any(|violation| matches!( + violation, + PolicyViolation::DuplicateMiddlewareConfigName { name } if name == "redactor" + ))); + } + + #[test] + fn validate_rejects_malformed_middleware_selector_patterns() { + let mut policy = restrictive_default_policy(); + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.endpoints.as_mut().unwrap().include = vec!["api[.example.com".into()]; + policy.network_middlewares.push(middleware); + + let errors = validate_sandbox_policy(&policy) + .expect_err("malformed selector") + .into_iter() + .map(|violation| violation.to_string()) + .collect::>() + .join("; "); + assert!(errors.contains("invalid host pattern"), "{errors}"); + } + + #[test] + fn middleware_host_selector_matching_is_case_insensitive() { + assert!(middleware_host_matches("*.Example.COM", "API.example.com").unwrap()); + assert!(!middleware_host_matches("*.example.com", "example.com").unwrap()); + assert!(middleware_host_matches("*", "deep.api.example.com").unwrap()); + } + + #[test] + fn validate_rejects_middleware_selector_matching_tls_skip_endpoint() { + let mut policy = restrictive_default_policy(); + policy + .network_middlewares + .push(middleware_config("redactor", "openshell/secrets")); + policy.network_policies.insert( + "api".into(), + NetworkPolicyRule { + name: "api".into(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".into(), + port: 443, + tls: "skip".into(), + ..Default::default() + }], + binaries: Vec::new(), + }, + ); + + let violations = validate_sandbox_policy(&policy).expect_err("tls skip conflict"); + assert!(violations.iter().any(|violation| matches!( + violation, + PolicyViolation::MiddlewareTlsSkipConflict { + middleware_name, + policy_name, + host, + } if middleware_name == "redactor" && policy_name == "api" && host == "api.example.com" + ))); + } + #[test] fn validate_rejects_non_sandbox_user() { let mut policy = restrictive_default_policy(); diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index 1c63e6ebc..04f390198 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -989,7 +989,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1008,7 +1007,6 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], - ..Default::default() }; let result = merge_policy( @@ -1037,7 +1035,6 @@ mod tests { name: "existing".to_string(), endpoints: vec![endpoint("api.github.com", 443)], binaries: vec![advisor_binary("/usr/bin/curl")], - ..Default::default() }, ); @@ -1048,7 +1045,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let result = merge_policy( @@ -1080,7 +1076,6 @@ mod tests { ..Default::default() }, ], - ..Default::default() }; let result = merge_policy( @@ -1112,7 +1107,6 @@ mod tests { path: "/usr/bin/python".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1126,7 +1120,6 @@ mod tests { ..Default::default() }], binaries: vec![advisor_binary("/usr/bin/python")], - ..Default::default() }; let result = merge_policy( @@ -1454,7 +1447,6 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1479,7 +1471,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let merged = merge_policy( @@ -1503,7 +1494,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; // Merge an *unrelated* rule for a different host. The proposed rule @@ -1534,7 +1524,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1547,7 +1536,6 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1579,7 +1567,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; // Endpoint exists in the policy but with a *different* binary. The @@ -1595,7 +1582,6 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1632,7 +1618,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1652,7 +1637,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1680,7 +1664,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1703,7 +1686,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1727,7 +1709,6 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], - ..Default::default() }; let merged = merge_policy( @@ -1752,7 +1733,6 @@ mod tests { name: "any_binary_rule".to_string(), endpoints: vec![endpoint("api.github.com", 443)], binaries: vec![], - ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1765,7 +1745,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1823,7 +1802,6 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], - ..Default::default() }; let composed = compose_effective_policy( &SandboxPolicy::default(), @@ -1855,7 +1833,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let result = merge_policy( composed, @@ -1924,7 +1901,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let result = merge_policy( policy, diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 1eb1b54d2..ddfbcaf7d 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -450,7 +450,6 @@ impl ProviderTypeProfile { NetworkPolicyRule { name: rule_name.to_string(), endpoints: self.endpoints.iter().map(endpoint_to_proto).collect(), - middleware: Vec::new(), binaries: self.binaries.iter().map(binary_to_proto).collect(), } } @@ -788,7 +787,6 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { request_body_credential_rewrite: endpoint.request_body_credential_rewrite, advisor_proposed: false, persisted_queries: endpoint.persisted_queries.clone(), - middleware: Vec::new(), graphql_persisted_queries: endpoint .graphql_persisted_queries .iter() diff --git a/crates/openshell-sandbox/src/mechanistic_mapper.rs b/crates/openshell-sandbox/src/mechanistic_mapper.rs index bb83ddb66..8ee2fc37f 100644 --- a/crates/openshell-sandbox/src/mechanistic_mapper.rs +++ b/crates/openshell-sandbox/src/mechanistic_mapper.rs @@ -162,7 +162,6 @@ pub fn generate_proposals(summaries: &[DenialSummary]) -> Vec { name: rule_name.clone(), endpoints: vec![endpoint], binaries, - middleware: Vec::new(), }; // Compute confidence. diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index ad4fdf5ba..d3bc213ba 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -5746,7 +5746,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -5960,7 +5959,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -6077,7 +6075,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6183,7 +6180,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let mechanistic_submit = handle_submit_policy_analysis( &state, @@ -6261,7 +6257,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let agent_submit = handle_submit_policy_analysis( &state, @@ -6389,7 +6384,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6490,7 +6484,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6591,7 +6584,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6685,7 +6677,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6770,7 +6761,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6859,7 +6849,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6951,7 +6940,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -7038,7 +7026,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let response = handle_submit_policy_analysis( @@ -7100,7 +7087,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-provider-prefix".to_string(), @@ -7215,7 +7201,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -7312,7 +7297,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -7398,7 +7382,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -7540,7 +7523,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -7666,7 +7648,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let step1 = handle_submit_policy_analysis( &state, @@ -7708,7 +7689,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let step2 = handle_submit_policy_analysis( &state, @@ -7840,7 +7820,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let submit_one = |rule_name: &str, rule: NetworkPolicyRule| { @@ -7949,7 +7928,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let submit_one = || { let state = state.clone(); @@ -8050,7 +8028,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -8182,7 +8159,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -8381,7 +8357,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, }; @@ -8410,7 +8385,6 @@ mod tests { path: "/usr/bin/node".to_string(), ..Default::default() }], - ..Default::default() }, }; @@ -8439,7 +8413,6 @@ mod tests { path: "/usr/bin/node".to_string(), ..Default::default() }], - ..Default::default() }, }; @@ -8467,7 +8440,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-1".to_string(), @@ -8536,7 +8508,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, )) .collect(), @@ -8565,7 +8536,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-merge".to_string(), @@ -8639,7 +8609,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, )) .collect(), @@ -8668,7 +8637,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-new".to_string(), diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 0b3548b06..c98ad2d62 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -1613,6 +1613,28 @@ mod tests { assert!(err.message().contains("TLD wildcard")); } + #[test] + fn validate_policy_safety_rejects_invalid_middleware_before_acceptance() { + use openshell_core::proto::{MiddlewareEndpointSelector, NetworkMiddlewareConfig}; + + let mut policy = openshell_policy::restrictive_default_policy(); + policy.network_middlewares.push(NetworkMiddlewareConfig { + name: "redactor".into(), + middleware: "openshell/secrets".into(), + on_error: "maybe".into(), + endpoints: Some(MiddlewareEndpointSelector { + include: vec!["api[.example.com".into()], + exclude: Vec::new(), + }), + ..Default::default() + }); + + let err = validate_policy_safety(&policy).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("invalid on_error")); + assert!(err.message().contains("invalid host pattern")); + } + #[test] fn validate_no_reserved_provider_policy_keys_rejects_reserved_key() { use openshell_core::proto::NetworkPolicyRule; diff --git a/crates/openshell-supervisor-middleware/Cargo.toml b/crates/openshell-supervisor-middleware/Cargo.toml index fdaeb2e82..4ae355894 100644 --- a/crates/openshell-supervisor-middleware/Cargo.toml +++ b/crates/openshell-supervisor-middleware/Cargo.toml @@ -16,10 +16,8 @@ openshell-core = { path = "../openshell-core" } miette = { workspace = true } prost-types = { workspace = true } regex = { workspace = true } -tonic = { workspace = true } - -[dev-dependencies] tokio = { workspace = true } +tonic = { workspace = true } [lints] workspace = true diff --git a/crates/openshell-supervisor-middleware/src/builtins/mod.rs b/crates/openshell-supervisor-middleware/src/builtins/mod.rs index d91ee745e..1db620220 100644 --- a/crates/openshell-supervisor-middleware/src/builtins/mod.rs +++ b/crates/openshell-supervisor-middleware/src/builtins/mod.rs @@ -2,3 +2,28 @@ // SPDX-License-Identifier: Apache-2.0 pub mod secrets; + +use miette::{Result, miette}; +use openshell_core::proto::{HttpRequestEvaluation, HttpRequestResult, MiddlewareBinding}; + +pub fn describe() -> Vec { + vec![secrets::describe()] +} + +pub fn validate_config(binding_id: &str, config: &prost_types::Struct) -> Result<()> { + match binding_id { + secrets::BINDING_ID => secrets::validate_config(config), + other => Err(miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + } +} + +pub fn evaluate_http_request(evaluation: &HttpRequestEvaluation) -> Result { + match evaluation.binding_id.as_str() { + secrets::BINDING_ID => secrets::evaluate_http_request(evaluation), + other => Err(miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + } +} diff --git a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs index 572102559..d88ac080d 100644 --- a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs +++ b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs @@ -5,10 +5,24 @@ use std::collections::HashMap; use std::sync::LazyLock; use miette::{Result, miette}; -use openshell_core::proto::{Decision, Finding, HttpRequestEvaluation, HttpRequestResult}; +use openshell_core::proto::{ + Decision, Finding, HttpRequestEvaluation, HttpRequestResult, MiddlewareBinding, +}; use regex::Regex; -use crate::BUILTIN_SECRETS; +pub const BINDING_ID: &str = "openshell/secrets"; +const OPERATION: &str = "HttpRequest"; +const PHASE: &str = "pre_credentials"; +const MAX_BODY_BYTES: u64 = 256 * 1024; + +pub fn describe() -> MiddlewareBinding { + MiddlewareBinding { + id: BINDING_ID.into(), + operation: OPERATION.into(), + phase: PHASE.into(), + max_body_bytes: MAX_BODY_BYTES, + } +} /// A named secret-detection pattern. The `kind` is an audit-safe label that /// flows into findings so operators can see *what* matched without seeing the @@ -51,7 +65,7 @@ pub fn validate_config(config: &prost_types::Struct) -> Result<()> { if mode != "redact" { return Err(miette!( "{} only supports config.secrets: redact in phase 1", - BUILTIN_SECRETS + BINDING_ID )); } Ok(()) @@ -61,7 +75,7 @@ pub fn evaluate_http_request(evaluation: &HttpRequestEvaluation) -> Result Result<()> { - match implementation { - BUILTIN_SECRETS => builtins::secrets::validate_config(config), - other => Err(miette!( - "middleware implementation '{other}' is not available in phase 1" - )), - } + builtins::validate_config(implementation, config) } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -85,6 +79,26 @@ impl TryFrom<&NetworkMiddlewareConfig> for ChainEntry { } } +/// A policy-selected middleware config joined with metadata reported by its +/// service's `Describe` call. A missing binding is retained so `on_error` can +/// decide whether the request fails open or closed. +#[derive(Debug, Clone)] +pub struct DescribedChainEntry { + entry: ChainEntry, + binding: Option, + max_body_bytes: usize, +} + +impl DescribedChainEntry { + pub fn max_body_bytes(&self) -> usize { + self.max_body_bytes + } + + pub fn on_error(&self) -> OnError { + self.entry.on_error + } +} + #[derive(Debug, Clone)] pub struct HttpRequestInput { pub request_id: String, @@ -138,15 +152,15 @@ enum OnErrorAction { /// Apply a middleware entry's `on_error` policy after a failure (service error or /// malformed response). Records a `failed` invocation for telemetry in both cases. fn apply_on_error( - entry: &ChainEntry, + entry: &DescribedChainEntry, reason: &str, applied: &mut Vec, ) -> OnErrorAction { - match entry.on_error { + match entry.entry.on_error { OnError::FailOpen => { applied.push(MiddlewareInvocation { - name: entry.name.clone(), - implementation: entry.implementation.clone(), + name: entry.entry.name.clone(), + implementation: entry.entry.implementation.clone(), decision: Decision::Allow, transformed: false, failed: true, @@ -155,8 +169,8 @@ fn apply_on_error( } OnError::FailClosed => { applied.push(MiddlewareInvocation { - name: entry.name.clone(), - implementation: entry.implementation.clone(), + name: entry.entry.name.clone(), + implementation: entry.entry.implementation.clone(), decision: Decision::Deny, transformed: false, failed: true, @@ -168,24 +182,102 @@ fn apply_on_error( #[derive(Clone)] pub struct ChainRunner { + state: Arc, +} + +struct MiddlewareServiceState { service: Arc, + manifest: OnceCell, } +static IN_PROCESS_SERVICE: LazyLock> = LazyLock::new(|| { + Arc::new(MiddlewareServiceState { + service: Arc::new(InProcessMiddlewareService), + manifest: OnceCell::new(), + }) +}); + impl Default for ChainRunner { fn default() -> Self { - Self::new(Arc::new(InProcessMiddlewareService)) + Self { + state: Arc::clone(&IN_PROCESS_SERVICE), + } } } impl ChainRunner { pub fn new(service: Arc) -> Self { - Self { service } + Self { + state: Arc::new(MiddlewareServiceState { + service, + manifest: OnceCell::new(), + }), + } + } + + async fn manifest(&self) -> Result<&MiddlewareManifest> { + self.state + .manifest + .get_or_try_init(|| async { + self.state + .service + .describe(Request::new(())) + .await + .map(tonic::Response::into_inner) + .map_err(|error| { + miette!( + "middleware Describe failed: {}", + safe_reason(&error.to_string()) + ) + }) + }) + .await + } + + pub async fn describe_chain(&self, entries: &[ChainEntry]) -> Result> { + let manifest = self.manifest().await?; + entries + .iter() + .map(|entry| { + let binding = manifest + .bindings + .iter() + .find(|binding| binding.id == entry.implementation) + .cloned(); + let max_body_bytes = binding + .as_ref() + .map(|binding| { + usize::try_from(binding.max_body_bytes).map_err(|_| { + miette!( + "middleware binding '{}' reports a body limit too large for this platform", + binding.id + ) + }) + }) + .transpose()? + .unwrap_or(0); + Ok(DescribedChainEntry { + entry: entry.clone(), + binding, + max_body_bytes, + }) + }) + .collect() } pub async fn evaluate( &self, entries: &[ChainEntry], input: HttpRequestInput, + ) -> Result { + let entries = self.describe_chain(entries).await?; + self.evaluate_described(&entries, input).await + } + + pub async fn evaluate_described( + &self, + entries: &[DescribedChainEntry], + input: HttpRequestInput, ) -> Result { let mut headers = input.headers.clone(); let mut body = input.body.clone(); @@ -195,8 +287,41 @@ impl ChainRunner { let mut applied = Vec::new(); for entry in entries { - let evaluation = build_evaluation(entry, &input, &headers, &body); + let Some(binding) = entry.binding.as_ref() else { + match apply_on_error(entry, "binding_not_described", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + }; + if body.len() > entry.max_body_bytes { + match apply_on_error(entry, "request_body_over_capacity", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + } + let evaluation = build_evaluation(entry, binding, &input, &headers, &body); let result = match self + .state .service .evaluate_http_request(Request::new(evaluation)) .await @@ -240,6 +365,23 @@ impl ChainRunner { } }; + if result.has_body && result.body.len() > entry.max_body_bytes { + match apply_on_error(entry, "response_body_over_capacity", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + } + // A result proposing unsafe header mutations is a malformed response: // route it through `on_error` instead of applying any of it. if validate_header_mutations(&headers, &result.add_headers).is_err() { @@ -268,19 +410,19 @@ impl ChainRunner { } for finding in result.findings { findings.push(NamespacedFinding { - middleware: entry.name.clone(), + middleware: entry.entry.name.clone(), finding, }); } if !result.metadata.is_empty() { metadata.insert( - entry.name.clone(), + entry.entry.name.clone(), result.metadata.clone().into_iter().collect(), ); } applied.push(MiddlewareInvocation { - name: entry.name.clone(), - implementation: entry.implementation.clone(), + name: entry.entry.name.clone(), + implementation: entry.entry.implementation.clone(), decision, transformed, failed: false, @@ -311,21 +453,22 @@ impl ChainRunner { } fn build_evaluation( - entry: &ChainEntry, + entry: &DescribedChainEntry, + binding: &MiddlewareBinding, input: &HttpRequestInput, headers: &BTreeMap, body: &[u8], ) -> HttpRequestEvaluation { HttpRequestEvaluation { api_version: API_VERSION.into(), - binding_id: entry.implementation.clone(), - phase: PRE_CREDENTIALS_PHASE.into(), + binding_id: binding.id.clone(), + phase: binding.phase.clone(), context: Some(RequestContext { request_id: input.request_id.clone(), sandbox_id: input.sandbox_id.clone(), originating_process: None, }), - config: Some(entry.config.clone()), + config: Some(entry.entry.config.clone()), target: Some(HttpRequestTarget { scheme: input.scheme.clone(), host: input.host.clone(), @@ -436,11 +579,16 @@ mod tests { } } - #[test] - fn phase_one_evaluation_omits_originating_process() { - let entry = entry("redact", OnError::FailClosed); + #[tokio::test] + async fn phase_one_evaluation_omits_originating_process() { + let entries = ChainRunner::default() + .describe_chain(&[entry("redact", OnError::FailClosed)]) + .await + .expect("describe chain"); + let entry = &entries[0]; + let binding = entry.binding.as_ref().expect("described binding"); let input = input("payload"); - let evaluation = build_evaluation(&entry, &input, &BTreeMap::new(), b"payload"); + let evaluation = build_evaluation(entry, binding, &input, &BTreeMap::new(), b"payload"); assert!( evaluation @@ -527,8 +675,9 @@ mod tests { .into_inner(); assert_eq!(manifest.api_version, API_VERSION); assert_eq!(manifest.bindings[0].id, BUILTIN_SECRETS); - assert_eq!(manifest.bindings[0].operation, HTTP_REQUEST_OPERATION); - assert_eq!(manifest.bindings[0].phase, PRE_CREDENTIALS_PHASE); + assert_eq!(manifest.bindings[0].operation, "HttpRequest"); + assert_eq!(manifest.bindings[0].phase, "pre_credentials"); + assert_eq!(manifest.bindings[0].max_body_bytes, 256 * 1024); } #[test] @@ -561,6 +710,8 @@ mod tests { /// evaluation. Used to exercise chain behavior the built-in cannot produce /// (explicit deny, metadata, findings, unsafe header mutations). struct ScriptedService { + binding_id: String, + max_body_bytes: u64, result: openshell_core::proto::HttpRequestResult, } @@ -569,13 +720,18 @@ mod tests { async fn describe( &self, _request: Request<()>, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - Ok(tonic::Response::new( - openshell_core::proto::MiddlewareManifest::default(), - )) + ) -> std::result::Result, tonic::Status> { + Ok(tonic::Response::new(MiddlewareManifest { + api_version: API_VERSION.into(), + name: "test/middleware".into(), + service_version: "test".into(), + bindings: vec![MiddlewareBinding { + id: self.binding_id.clone(), + operation: "HttpRequest".into(), + phase: "pre_credentials".into(), + max_body_bytes: self.max_body_bytes, + }], + })) } async fn validate_config( @@ -604,6 +760,14 @@ mod tests { } } + fn scripted_service(result: openshell_core::proto::HttpRequestResult) -> ScriptedService { + ScriptedService { + binding_id: BUILTIN_SECRETS.into(), + max_body_bytes: 256 * 1024, + result, + } + } + fn allow_result() -> openshell_core::proto::HttpRequestResult { openshell_core::proto::HttpRequestResult { decision: Decision::Allow as i32, @@ -617,14 +781,49 @@ mod tests { } #[tokio::test] - async fn deny_decision_short_circuits_chain() { + async fn descriptors_are_resolved_from_any_middleware_service() { let runner = ChainRunner::new(Arc::new(ScriptedService { - result: openshell_core::proto::HttpRequestResult { + binding_id: "example/redactor".into(), + max_body_bytes: 4096, + result: allow_result(), + })); + let entry = ChainEntry { + name: "external".into(), + implementation: "example/redactor".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }; + + let described = runner + .describe_chain(std::slice::from_ref(&entry)) + .await + .expect("describe external middleware"); + assert_eq!(described[0].max_body_bytes(), 4096); + assert_eq!( + described[0] + .binding + .as_ref() + .expect("described binding") + .phase, + "pre_credentials" + ); + + let outcome = runner + .evaluate_described(&described, input("hello")) + .await + .expect("evaluate external middleware"); + assert!(outcome.allowed); + } + + #[tokio::test] + async fn deny_decision_short_circuits_chain() { + let runner = ChainRunner::new(Arc::new(scripted_service( + openshell_core::proto::HttpRequestResult { decision: Decision::Deny as i32, reason: "blocked_by_policy".into(), ..allow_result() }, - })); + ))); let outcome = runner .evaluate( &[ @@ -645,8 +844,8 @@ mod tests { #[tokio::test] async fn metadata_and_findings_are_namespaced_per_config() { - let runner = ChainRunner::new(Arc::new(ScriptedService { - result: openshell_core::proto::HttpRequestResult { + let runner = ChainRunner::new(Arc::new(scripted_service( + openshell_core::proto::HttpRequestResult { findings: vec![Finding { r#type: "pii.email".into(), label: "email address".into(), @@ -658,7 +857,7 @@ mod tests { .collect(), ..allow_result() }, - })); + ))); let outcome = runner .evaluate( &[ @@ -683,16 +882,14 @@ mod tests { } fn unsafe_header_service() -> ScriptedService { - ScriptedService { - result: openshell_core::proto::HttpRequestResult { - add_headers: std::iter::once(( - "x-openshell-middleware-inject".to_string(), - "ok\r\nHost: evil".to_string(), - )) - .collect(), - ..allow_result() - }, - } + scripted_service(openshell_core::proto::HttpRequestResult { + add_headers: std::iter::once(( + "x-openshell-middleware-inject".to_string(), + "ok\r\nHost: evil".to_string(), + )) + .collect(), + ..allow_result() + }) } #[tokio::test] @@ -724,13 +921,79 @@ mod tests { } #[tokio::test] - async fn unspecified_decision_uses_fail_closed() { + async fn oversized_replacement_body_honors_on_error() { let runner = ChainRunner::new(Arc::new(ScriptedService { + binding_id: BUILTIN_SECRETS.into(), + max_body_bytes: 4, result: openshell_core::proto::HttpRequestResult { - decision: Decision::Unspecified as i32, + body: b"too large".to_vec(), + has_body: true, ..allow_result() }, })); + let fail_open = entry("small", OnError::FailOpen); + let mut fail_closed = fail_open.clone(); + fail_closed.on_error = OnError::FailClosed; + + let open_outcome = runner + .evaluate(&[fail_open], input("safe")) + .await + .expect("fail-open evaluation"); + assert!(open_outcome.allowed); + assert_eq!(open_outcome.body, b"safe"); + assert!(open_outcome.applied[0].failed); + + let closed_outcome = runner + .evaluate(&[fail_closed], input("safe")) + .await + .expect("fail-closed evaluation"); + assert!(!closed_outcome.allowed); + assert_eq!( + closed_outcome.reason, + "middleware_failed: response_body_over_capacity" + ); + assert!(closed_outcome.applied[0].failed); + } + + #[tokio::test] + async fn oversized_request_body_honors_on_error() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + binding_id: BUILTIN_SECRETS.into(), + max_body_bytes: 4, + result: allow_result(), + })); + let fail_open = entry("small", OnError::FailOpen); + let mut fail_closed = fail_open.clone(); + fail_closed.on_error = OnError::FailClosed; + + let open_outcome = runner + .evaluate(&[fail_open], input("hello")) + .await + .expect("fail-open evaluation"); + assert!(open_outcome.allowed); + assert_eq!(open_outcome.body, b"hello"); + assert!(open_outcome.applied[0].failed); + + let closed_outcome = runner + .evaluate(&[fail_closed], input("hello")) + .await + .expect("fail-closed evaluation"); + assert!(!closed_outcome.allowed); + assert_eq!( + closed_outcome.reason, + "middleware_failed: request_body_over_capacity" + ); + assert!(closed_outcome.applied[0].failed); + } + + #[tokio::test] + async fn unspecified_decision_uses_fail_closed() { + let runner = ChainRunner::new(Arc::new(scripted_service( + openshell_core::proto::HttpRequestResult { + decision: Decision::Unspecified as i32, + ..allow_result() + }, + ))); let outcome = runner .evaluate(&[entry("redact", OnError::FailClosed)], input("hello")) diff --git a/crates/openshell-supervisor-middleware/src/service.rs b/crates/openshell-supervisor-middleware/src/service.rs index cbd9231cd..51df8d070 100644 --- a/crates/openshell-supervisor-middleware/src/service.rs +++ b/crates/openshell-supervisor-middleware/src/service.rs @@ -3,15 +3,12 @@ use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; use openshell_core::proto::{ - HttpRequestEvaluation, HttpRequestResult, MiddlewareBinding, MiddlewareManifest, - ValidateConfigRequest, ValidateConfigResponse, + HttpRequestEvaluation, HttpRequestResult, MiddlewareManifest, ValidateConfigRequest, + ValidateConfigResponse, }; use tonic::{Request, Response, Status}; -use crate::{ - API_VERSION, BUILTIN_SECRETS, HTTP_REQUEST_OPERATION, PRE_CREDENTIALS_PHASE, builtins, - safe_reason, validate_builtin_config, -}; +use crate::{API_VERSION, builtins, safe_reason, validate_builtin_config}; #[derive(Debug, Default)] pub struct InProcessMiddlewareService; @@ -26,11 +23,7 @@ impl SupervisorMiddleware for InProcessMiddlewareService { api_version: API_VERSION.into(), name: "openshell/in-process".into(), service_version: env!("CARGO_PKG_VERSION").into(), - bindings: vec![MiddlewareBinding { - id: BUILTIN_SECRETS.into(), - operation: HTTP_REQUEST_OPERATION.into(), - phase: PRE_CREDENTIALS_PHASE.into(), - }], + bindings: builtins::describe(), })) } @@ -58,13 +51,8 @@ impl SupervisorMiddleware for InProcessMiddlewareService { request: Request, ) -> Result, Status> { let request = request.into_inner(); - let result = match request.binding_id.as_str() { - BUILTIN_SECRETS => builtins::secrets::evaluate_http_request(&request), - other => Err(miette::miette!( - "middleware implementation '{other}' is not available in phase 1" - )), - } - .map_err(|err| Status::invalid_argument(safe_reason(&err.to_string())))?; + let result = builtins::evaluate_http_request(&request) + .map_err(|err| Status::invalid_argument(safe_reason(&err.to_string())))?; Ok(Response::new(result)) } } diff --git a/crates/openshell-supervisor-network/data/sandbox-policy.rego b/crates/openshell-supervisor-network/data/sandbox-policy.rego index 52f6f1046..fcc5838e1 100644 --- a/crates/openshell-supervisor-network/data/sandbox-policy.rego +++ b/crates/openshell-supervisor-network/data/sandbox-policy.rego @@ -842,23 +842,6 @@ _policy_endpoint_configs(policy) := [ep | endpoint_has_extended_config(ep) ] -# Collect matching endpoint identities across all policies. Iterates over -# _matching_policy_names (a set, safe from regorus variable collisions) then -# returns the selected policy name plus endpoint index/path. Rust uses that -# identity to look up middleware attachment from policy data. -_matching_endpoint_contexts := [ctx | - some pname - _matching_policy_names[pname] - policy := data.network_policies[pname] - ep := policy.endpoints[i] - endpoint_matches_request(ep, input.network) - ctx := { - "policy": pname, - "endpoint_index": i, - "endpoint_path": object.get(ep, "path", ""), - } -] - _matching_endpoint_configs := [cfg | some pname _matching_policy_names[pname] @@ -871,8 +854,6 @@ matched_endpoint_config := _matching_endpoint_configs[0] if { count(_matching_endpoint_configs) > 0 } -network_policies := object.get(data, "network_policies", {}) - network_middlewares := object.get(data, "network_middlewares", []) _policy_has_exact_declared_endpoint(policy) if { diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index c773fdcf4..8383b6bb2 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -453,8 +453,7 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let chain = - engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; let req = match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) .await? @@ -773,20 +772,45 @@ pub(crate) enum MiddlewareApplyResult { Denied(String), } +fn middleware_chain_body_limit( + chain: &[openshell_supervisor_middleware::DescribedChainEntry], +) -> Option { + chain + .iter() + .map(openshell_supervisor_middleware::DescribedChainEntry::max_body_bytes) + .min() +} + pub(crate) async fn apply_middleware_chain( req: crate::l7::provider::L7Request, client: &mut C, ctx: &L7EvalContext, chain: Vec, generation_guard: &PolicyGenerationGuard, +) -> Result { + apply_middleware_chain_for_scheme(req, client, ctx, "https", chain, generation_guard).await +} + +pub(crate) async fn apply_middleware_chain_for_scheme( + req: crate::l7::provider::L7Request, + client: &mut C, + ctx: &L7EvalContext, + scheme: &str, + chain: Vec, + generation_guard: &PolicyGenerationGuard, ) -> Result { if chain.is_empty() { return Ok(MiddlewareApplyResult::Allowed(req)); } + let runner = openshell_supervisor_middleware::ChainRunner::default(); + let chain = runner.describe_chain(&chain).await?; + let max_body_bytes = + middleware_chain_body_limit(&chain).expect("non-empty middleware chain has a body limit"); let buffered = match crate::l7::rest::buffer_request_body_for_middleware( &req, client, Some(generation_guard), + max_body_bytes, ) .await? { @@ -797,21 +821,8 @@ pub(crate) async fn apply_middleware_chain, + query: String, + body: Vec, +) -> openshell_supervisor_middleware::HttpRequestInput { + openshell_supervisor_middleware::HttpRequestInput { + request_id: uuid::Uuid::new_v4().to_string(), + sandbox_id: openshell_ocsf::ctx::ctx().sandbox_id.clone(), + scheme: scheme.into(), + host: ctx.host.clone(), + port: ctx.port, + method: req.action.clone(), + path: req.target.clone(), + query, + headers, + body, + } +} + fn raw_query_from_request_headers(headers: &[u8]) -> Result { let header_str = std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; @@ -846,12 +879,12 @@ fn raw_query_from_request_headers(headers: &[u8]) -> Result { fn resolve_unbuffered_body( ctx: &L7EvalContext, req: crate::l7::provider::L7Request, - chain: &[openshell_supervisor_middleware::ChainEntry], + chain: &[openshell_supervisor_middleware::DescribedChainEntry], recoverable: bool, ) -> MiddlewareApplyResult { let all_fail_open = chain .iter() - .all(|entry| entry.on_error == openshell_supervisor_middleware::OnError::FailOpen); + .all(|entry| entry.on_error() == openshell_supervisor_middleware::OnError::FailOpen); if recoverable && all_fail_open { emit_middleware_body_unavailable(ctx, false); return MiddlewareApplyResult::Allowed(req); @@ -1187,8 +1220,7 @@ where let _ = &eval_target; if allowed || config.enforcement == EnforcementMode::Audit { - let chain = - engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; let req = match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) .await? @@ -1457,8 +1489,7 @@ where } if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let chain = - engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; let req = match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) .await? @@ -1682,8 +1713,7 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let chain = - engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; let req = match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) .await? @@ -2136,8 +2166,7 @@ where let req = if let Some(engine) = middleware_engine { let input = middleware_network_input(ctx); - let (chain, generation) = - engine.query_middleware_chain_with_generation(&input, &req.target)?; + let (chain, generation) = engine.query_middleware_chain_with_generation(&input)?; if generation != generation_guard.captured_generation() { return Ok(()); } @@ -2326,10 +2355,11 @@ network_middlewares: - name: request-middleware middleware: {middleware_impl} on_error: {on_error} + endpoints: + include: ["api.example.test"] network_policies: rest_api: name: rest_api - middleware: ["request-middleware"] endpoints: - host: api.example.test port: 8080 @@ -2785,10 +2815,11 @@ network_middlewares: - name: request-middleware middleware: example/unavailable on_error: fail_closed + endpoints: + include: ["api.example.test"] network_policies: jsonrpc_api: name: jsonrpc_api - middleware: ["request-middleware"] endpoints: - host: api.example.test port: 443 @@ -2929,8 +2960,8 @@ network_policies: .unwrap(); } - #[test] - fn over_capacity_resolution_honors_on_error() { + #[tokio::test] + async fn over_capacity_resolution_honors_on_error() { use openshell_supervisor_middleware::{ChainEntry, OnError}; let ctx = L7EvalContext { @@ -2963,21 +2994,31 @@ network_policies: ..fail_open.clone() }; + let runner = openshell_supervisor_middleware::ChainRunner::default(); + let open_chain = runner + .describe_chain(std::slice::from_ref(&fail_open)) + .await + .expect("describe fail-open chain"); + let mixed_chain = runner + .describe_chain(&[fail_open.clone(), fail_closed]) + .await + .expect("describe mixed chain"); + // Recoverable (Content-Length over cap, nothing consumed) + all fail-open // -> stream through unprocessed. assert!(matches!( - resolve_unbuffered_body(&ctx, req(), std::slice::from_ref(&fail_open), true), + resolve_unbuffered_body(&ctx, req(), &open_chain, true), MiddlewareApplyResult::Allowed(_) )); // Any fail-closed entry -> deny. assert!(matches!( - resolve_unbuffered_body(&ctx, req(), &[fail_open.clone(), fail_closed], true), + resolve_unbuffered_body(&ctx, req(), &mixed_chain, true), MiddlewareApplyResult::Denied(_) )); // Not recoverable (chunked overflow already consumed bytes) -> deny even // when every entry is fail-open. assert!(matches!( - resolve_unbuffered_body(&ctx, req(), &[fail_open], false), + resolve_unbuffered_body(&ctx, req(), &open_chain, false), MiddlewareApplyResult::Denied(_) )); } @@ -2992,6 +3033,40 @@ network_policies: assert_eq!(query, "token=a%2Bb&scope=private"); } + #[test] + fn middleware_request_input_preserves_plain_http_scheme() { + let req = crate::l7::provider::L7Request { + action: "POST".into(), + target: "/v1/messages".into(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }; + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 80, + policy_name: "api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: Vec::new(), + cmdline_paths: Vec::new(), + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let input = middleware_request_input( + "http", + &req, + &ctx, + BTreeMap::new(), + String::new(), + Vec::new(), + ); + + assert_eq!(input.scheme, "http"); + } + /// Tracing layer that captures emitted `OcsfEvent`s for assertions. struct OcsfCaptureLayer(Arc>>); @@ -3096,16 +3171,17 @@ network_policies: #[tokio::test] async fn passthrough_relay_runs_middleware_redaction() { // A no-protocol endpoint takes the credential-injection passthrough path; - // policy-level middleware must still inspect and redact its body. + // host-selected middleware must still inspect and redact its body. let data = r#" network_middlewares: - name: request-middleware middleware: openshell/secrets on_error: fail_closed + endpoints: + include: ["api.example.test"] network_policies: passthrough_api: name: passthrough_api - middleware: ["request-middleware"] endpoints: - host: api.example.test port: 8080 @@ -3197,10 +3273,11 @@ network_middlewares: - name: request-middleware middleware: example/unavailable on_error: fail_closed + endpoints: + include: ["gateway.example.test"] network_policies: ws_api: name: ws_api - middleware: ["request-middleware"] endpoints: - host: gateway.example.test port: 443 diff --git a/crates/openshell-supervisor-network/src/l7/rest.rs b/crates/openshell-supervisor-network/src/l7/rest.rs index 19f73e2ad..15825d1b2 100644 --- a/crates/openshell-supervisor-network/src/l7/rest.rs +++ b/crates/openshell-supervisor-network/src/l7/rest.rs @@ -27,7 +27,19 @@ const MAX_REWRITE_BODY_BYTES: usize = 256 * 1024; /// Maximum body bytes for `SigV4` body-signing mode. Larger than the credential /// rewrite limit because Bedrock payloads can be several megabytes. const MAX_SIGV4_BODY_BYTES: usize = 10 * 1024 * 1024; -pub(crate) const MAX_MIDDLEWARE_BODY_BYTES: usize = MAX_REWRITE_BODY_BYTES; +#[cfg(test)] +async fn max_middleware_body_bytes() -> usize { + let chain = openshell_supervisor_middleware::ChainRunner::default() + .describe_chain(&[openshell_supervisor_middleware::ChainEntry { + name: "test".into(), + implementation: openshell_supervisor_middleware::BUILTIN_SECRETS.into(), + config: prost_types::Struct::default(), + on_error: openshell_supervisor_middleware::OnError::FailClosed, + }]) + .await + .expect("describe built-in middleware"); + chain[0].max_body_bytes() +} const RELAY_BUF_SIZE: usize = 8192; const HTTP_METHOD_PREFIXES: &[&[u8]] = &[ b"GET ", @@ -820,6 +832,7 @@ pub(crate) async fn buffer_request_body_for_middleware( req: &L7Request, client: &mut C, generation_guard: Option<&PolicyGenerationGuard>, + max_body_bytes: usize, ) -> Result { let header_end = req .raw_header @@ -840,7 +853,7 @@ pub(crate) async fn buffer_request_body_for_middleware( let Ok(len) = usize::try_from(len) else { return Ok(BufferResult::OverCapacity { recoverable: true }); }; - if len > MAX_MIDDLEWARE_BODY_BYTES { + if len > max_body_bytes { return Ok(BufferResult::OverCapacity { recoverable: true }); } let initial_len = already_read.len().min(len); @@ -869,14 +882,18 @@ pub(crate) async fn buffer_request_body_for_middleware( } BodyLength::Chunked => { // Chunked bodies are decoded incrementally into the payload bytes - // middleware expects. On overflow, we have already consumed wire - // bytes from the client stream and cannot re-enter the normal raw - // relay path without a separate splice-through buffer. - Ok(collect_chunked_body(client, already_read, generation_guard) - .await - .map_or(BufferResult::OverCapacity { recoverable: false }, |body| { - BufferResult::Buffered(BufferedRequestBody { headers, body }) - })) + // middleware expects, but the middleware cap counts the complete + // wire representation, including framing and trailers. On overflow, + // we have already consumed wire bytes from the client stream and + // cannot re-enter the normal raw relay path without a separate + // splice-through buffer. + Ok( + collect_chunked_body(client, already_read, generation_guard, Some(max_body_bytes)) + .await + .map_or(BufferResult::OverCapacity { recoverable: false }, |body| { + BufferResult::Buffered(BufferedRequestBody { headers, body }) + }), + ) } } } @@ -953,7 +970,7 @@ async fn collect_and_rewrite_request_body( Ok(PreparedRequestBody { headers, body }) } BodyLength::Chunked => { - let body = collect_chunked_body(client, already_read, generation_guard).await?; + let body = collect_chunked_body(client, already_read, generation_guard, None).await?; let (mut headers, body) = rewrite_buffered_body(rewritten_headers, original_header_str, body, resolver)?; headers = set_content_length(&headers, body.len())?; @@ -1125,15 +1142,19 @@ async fn collect_chunked_body( client: &mut C, already_read: &[u8], generation_guard: Option<&PolicyGenerationGuard>, + max_wire_bytes: Option, ) -> Result> { - let mut buffered_pos = 0usize; + let mut read_state = ChunkedReadState { + buffered_pos: 0, + wire_bytes: 0, + max_wire_bytes, + }; let mut body = Vec::new(); loop { - let size_line = - read_chunked_line(client, already_read, &mut buffered_pos, generation_guard) - .await - .map_err(|e| miette!("Chunked body ended before chunk-size line: {e}"))?; + let size_line = read_chunked_line(client, already_read, &mut read_state, generation_guard) + .await + .map_err(|e| miette!("Chunked body ended before chunk-size line: {e}"))?; let size_line = std::str::from_utf8(&size_line) .into_diagnostic() .map_err(|_| miette!("Invalid UTF-8 in chunk-size line"))?; @@ -1149,7 +1170,7 @@ async fn collect_chunked_body( if chunk_size == 0 { loop { let trailer_line = - read_chunked_line(client, already_read, &mut buffered_pos, generation_guard) + read_chunked_line(client, already_read, &mut read_state, generation_guard) .await .map_err(|e| { miette!("Chunked body ended before trailer terminator: {e}") @@ -1168,7 +1189,7 @@ async fn collect_chunked_body( read_buffered_exact( client, already_read, - &mut buffered_pos, + &mut read_state, chunk_size, &mut body, generation_guard, @@ -1180,7 +1201,7 @@ async fn collect_chunked_body( read_buffered_exact( client, already_read, - &mut buffered_pos, + &mut read_state, 2, &mut chunk_crlf, generation_guard, @@ -1193,15 +1214,21 @@ async fn collect_chunked_body( } } +struct ChunkedReadState { + buffered_pos: usize, + wire_bytes: usize, + max_wire_bytes: Option, +} + async fn read_chunked_line( client: &mut C, already_read: &[u8], - buffered_pos: &mut usize, + state: &mut ChunkedReadState, generation_guard: Option<&PolicyGenerationGuard>, ) -> Result> { let mut line = Vec::new(); loop { - let byte = read_buffered_byte(client, already_read, buffered_pos, generation_guard).await?; + let byte = read_buffered_byte(client, already_read, state, generation_guard).await?; line.push(byte); if line.len() > MAX_REWRITE_BODY_BYTES { return Err(miette!( @@ -1218,13 +1245,13 @@ async fn read_chunked_line( async fn read_buffered_exact( client: &mut C, already_read: &[u8], - buffered_pos: &mut usize, + state: &mut ChunkedReadState, len: usize, out: &mut Vec, generation_guard: Option<&PolicyGenerationGuard>, ) -> Result<()> { for _ in 0..len { - let byte = read_buffered_byte(client, already_read, buffered_pos, generation_guard).await?; + let byte = read_buffered_byte(client, already_read, state, generation_guard).await?; out.push(byte); } Ok(()) @@ -1233,18 +1260,30 @@ async fn read_buffered_exact( async fn read_buffered_byte( client: &mut C, already_read: &[u8], - buffered_pos: &mut usize, + state: &mut ChunkedReadState, generation_guard: Option<&PolicyGenerationGuard>, ) -> Result { - if *buffered_pos < already_read.len() { - let byte = already_read[*buffered_pos]; - *buffered_pos += 1; - return Ok(byte); - } - let byte = client.read_u8().await.into_diagnostic()?; - if let Some(guard) = generation_guard { - guard.ensure_current()?; + if state + .max_wire_bytes + .is_some_and(|max| state.wire_bytes >= max) + { + return Err(miette!( + "chunked body wire representation exceeds middleware buffer limit" + )); } + + let byte = if state.buffered_pos < already_read.len() { + let byte = already_read[state.buffered_pos]; + state.buffered_pos += 1; + byte + } else { + let byte = client.read_u8().await.into_diagnostic()?; + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + byte + }; + state.wire_bytes += 1; Ok(byte) } @@ -3264,6 +3303,7 @@ mod tests { &mut client, b"5\r\nhello\r\n6;ext=value\r\n world\r\n0\r\nx-checksum: abc\r\n\r\n", None, + None, ) .await .expect("chunked body should decode"); @@ -3271,6 +3311,40 @@ mod tests { assert_eq!(body, b"hello world"); } + #[tokio::test] + async fn middleware_chunked_wire_body_at_cap_is_allowed() { + let max_body_bytes = max_middleware_body_bytes().await; + let payload_len = max_body_bytes - 14; + let mut wire = format!("{payload_len:x}\r\n").into_bytes(); + wire.extend(std::iter::repeat_n(b'x', payload_len)); + wire.extend_from_slice(b"\r\n0\r\n\r\n"); + assert_eq!(wire.len(), max_body_bytes); + + let body = collect_chunked_body(&mut tokio::io::empty(), &wire, None, Some(max_body_bytes)) + .await + .expect("wire representation at the cap should be allowed"); + + assert_eq!(body.len(), payload_len); + } + + #[tokio::test] + async fn middleware_chunked_wire_body_over_cap_is_rejected() { + let max_body_bytes = max_middleware_body_bytes().await; + let payload_len = max_body_bytes - 13; + let mut wire = format!("{payload_len:x}\r\n").into_bytes(); + wire.extend(std::iter::repeat_n(b'x', payload_len)); + wire.extend_from_slice(b"\r\n0\r\n\r\n"); + assert_eq!(wire.len(), max_body_bytes + 1); + assert!(payload_len < max_body_bytes); + + let error = + collect_chunked_body(&mut tokio::io::empty(), &wire, None, Some(max_body_bytes)) + .await + .expect_err("wire framing over the cap must be rejected"); + + assert!(error.to_string().contains("wire representation")); + } + /// SEC-009: Bare LF in headers enables header injection. #[tokio::test] async fn reject_bare_lf_in_headers() { diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index 3efec0212..c4e773996 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -135,17 +135,13 @@ impl TunnelPolicyEngine { &self.engine } - /// Query the ordered middleware chain for a request path within this tunnel. - pub fn query_middleware_chain( - &self, - input: &NetworkInput, - request_path: &str, - ) -> Result> { + /// Query the ordered middleware chain for a destination within this tunnel. + pub fn query_middleware_chain(&self, input: &NetworkInput) -> Result> { let mut engine = self .engine .lock() .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; - query_middleware_chain_locked(&mut engine, input, request_path) + query_middleware_chain_locked(&mut engine, input) } } @@ -208,21 +204,21 @@ impl OpaEngine { /// gap between user-specified symlink paths (e.g., `/usr/bin/python3`) and /// kernel-resolved canonical paths (e.g., `/usr/bin/python3.11`). pub fn from_proto_with_pid(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> Result { + if let Err(violations) = openshell_policy::validate_sandbox_policy(proto) { + let errors = violations + .iter() + .map(ToString::to_string) + .collect::>() + .join("\n"); + return Err(miette::miette!("policy validation failed:\n{errors}")); + } + let data_json_str = proto_to_opa_data_json(proto, entrypoint_pid); // Parse back to Value for preprocessing, then re-serialize let mut data: serde_json::Value = serde_json::from_str(&data_json_str) .map_err(|e| miette::miette!("internal: failed to parse proto JSON: {e}"))?; - // Validate BEFORE expanding presets - let middleware_errors = validate_middleware_policies(&data); - if !middleware_errors.is_empty() { - return Err(miette::miette!( - "middleware policy validation failed:\n{}", - middleware_errors.join("\n") - )); - } - let (errors, warnings) = crate::l7::validate_l7_policies(&data); for w in &warnings { openshell_ocsf::ocsf_emit!( @@ -571,18 +567,17 @@ impl OpaEngine { } } - /// Query the ordered middleware chain for a parsed HTTP request path. + /// Query the ordered middleware chain for an admitted destination. pub fn query_middleware_chain_with_generation( &self, input: &NetworkInput, - request_path: &str, ) -> Result<(Vec, u64)> { let mut engine = self .engine .lock() .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; let generation = self.current_generation(); - let chain = query_middleware_chain_locked(&mut engine, input, request_path)?; + let chain = query_middleware_chain_locked(&mut engine, input)?; Ok((chain, generation)) } @@ -749,17 +744,9 @@ fn network_input_json(input: &NetworkInput) -> serde_json::Value { }) } -#[derive(Debug, Clone)] -struct MatchedEndpointContext { - policy_name: String, - endpoint_index: usize, - endpoint_path: String, -} - fn query_middleware_chain_locked( engine: &mut regorus::Engine, input: &NetworkInput, - request_path: &str, ) -> Result> { engine .set_input_json(&network_input_json(input).to_string()) @@ -772,37 +759,7 @@ fn query_middleware_chain_locked( if configs.is_empty() { return Ok(Vec::new()); } - let contexts_val = engine - .eval_rule("data.openshell.sandbox._matching_endpoint_contexts".into()) - .map_err(|e| miette::miette!("{e}"))?; - let contexts = parse_endpoint_contexts(&contexts_val); - let policies_val = engine - .eval_rule("data.openshell.sandbox.network_policies".into()) - .map_err(|e| miette::miette!("{e}"))?; - let Some(context) = select_endpoint_context(&contexts, request_path, &policies_val)? else { - return global_middleware_entries(&configs, &input.host, &HashSet::new()); - }; - let (policy_middleware, endpoint_middleware) = - middleware_for_endpoint_identity(&policies_val, context)?; - - let mut explicit = Vec::new(); - for name in policy_middleware.iter().chain(endpoint_middleware.iter()) { - if !explicit.contains(name) { - explicit.push(name.clone()); - } - } - let explicit_set: HashSet = explicit.iter().cloned().collect(); - let mut ordered = global_middleware_entries(&configs, &input.host, &explicit_set)?; - for name in explicit { - if !ordered.iter().any(|entry| entry.name == name) { - let config = configs - .iter() - .find(|config| get_str(config, "name").as_deref() == Some(name.as_str())) - .ok_or_else(|| miette::miette!("unknown middleware config '{name}'"))?; - ordered.push(chain_entry_from_value(config)?); - } - } - Ok(ordered) + global_middleware_entries(&configs, &input.host) } fn parse_middleware_configs(value: ®orus::Value) -> Result> { @@ -815,169 +772,37 @@ fn parse_middleware_configs(value: ®orus::Value) -> Result Vec { - let regorus::Value::Array(values) = value else { - return Vec::new(); - }; - values - .iter() - .filter_map(|value| { - let regorus::Value::Object(_) = value else { - return None; - }; - Some(MatchedEndpointContext { - policy_name: get_str(value, "policy").unwrap_or_default(), - endpoint_index: get_usize(value, "endpoint_index").unwrap_or_default(), - endpoint_path: get_str(value, "endpoint_path").unwrap_or_default(), - }) - }) - .collect() -} - -fn middleware_for_endpoint_identity( - policies: ®orus::Value, - context: &MatchedEndpointContext, -) -> Result<(Vec, Vec)> { - let policy = get_field(policies, &context.policy_name).ok_or_else(|| { - miette::miette!( - "matched endpoint policy '{}' was not found in OPA data", - context.policy_name - ) - })?; - let endpoint = get_array(policy, "endpoints") - .and_then(|endpoints| endpoints.get(context.endpoint_index)) - .ok_or_else(|| { - miette::miette!( - "matched endpoint {}[{}] was not found in OPA data", - context.policy_name, - context.endpoint_index - ) - })?; - Ok(( - get_str_array(policy, "middleware"), - get_str_array(endpoint, "middleware"), - )) -} - -fn select_endpoint_context<'a>( - contexts: &'a [MatchedEndpointContext], - request_path: &str, - policies: ®orus::Value, -) -> Result> { - let matching: Vec<_> = contexts - .iter() - .filter(|context| crate::l7::endpoint_path_matches(&context.endpoint_path, request_path)) - .map(|context| (endpoint_path_specificity(&context.endpoint_path), context)) - .collect(); - let Some(max_specificity) = matching.iter().map(|(specificity, _)| *specificity).max() else { - return Ok(None); - }; - let best: Vec<_> = matching - .into_iter() - .filter(|(specificity, _)| *specificity == max_specificity) - .map(|(_, context)| context) - .collect(); - if let Some((first, rest)) = best.split_first() { - let first_middleware = explicit_middleware_for_endpoint_identity(policies, first)?; - for context in rest { - if explicit_middleware_for_endpoint_identity(policies, context)? != first_middleware { - let matches = best - .iter() - .map(|context| { - format!( - "{}[{}] path={}", - context.policy_name, - context.endpoint_index, - if context.endpoint_path.is_empty() { - "" - } else { - context.endpoint_path.as_str() - } - ) - }) - .collect::>() - .join(", "); - return Err(miette::miette!( - "ambiguous middleware endpoint match for request path '{request_path}': {matches}" - )); - } - } - } - Ok(best.into_iter().next()) -} - -fn explicit_middleware_for_endpoint_identity( - policies: ®orus::Value, - context: &MatchedEndpointContext, -) -> Result> { - let (policy_middleware, endpoint_middleware) = - middleware_for_endpoint_identity(policies, context)?; - Ok(dedup_middleware_names( - policy_middleware.iter().chain(endpoint_middleware.iter()), - )) -} - -fn dedup_middleware_names<'a>(names: impl IntoIterator) -> Vec { - let mut deduped = Vec::new(); - for name in names { - if !deduped.contains(name) { - deduped.push(name.clone()); - } - } - deduped -} - -fn endpoint_path_specificity(path: &str) -> usize { - if path.is_empty() { - 0 - } else { - path.chars().filter(|c| *c != '*').count() - } -} - -fn global_middleware_entries( - configs: &[regorus::Value], - host: &str, - explicit: &HashSet, -) -> Result> { +fn global_middleware_entries(configs: &[regorus::Value], host: &str) -> Result> { let mut entries = Vec::new(); for config in configs { - let name = get_str(config, "name").unwrap_or_default(); - if explicit.contains(&name) { - continue; - } - if middleware_selector_matches(config, host) { + if middleware_selector_matches(config, host)? { entries.push(chain_entry_from_value(config)?); } } Ok(entries) } -fn middleware_selector_matches(config: ®orus::Value, host: &str) -> bool { +fn middleware_selector_matches(config: ®orus::Value, host: &str) -> Result { let Some(selector) = get_field(config, "endpoints") else { - return false; + return Ok(false); }; let include_patterns = get_str_array(selector, "include"); let exclude_patterns = get_str_array(selector, "exclude"); - let matches_include = !include_patterns.is_empty() - && include_patterns - .iter() - .any(|pattern| host_matches(pattern, host)); + let matches_include = include_patterns + .iter() + .try_fold(false, |matched, pattern| { + openshell_policy::middleware_host_matches(pattern, host) + .map(|matches| matched || matches) + .map_err(|error| miette::miette!(error)) + })?; let matches_exclude = exclude_patterns .iter() - .any(|pattern| host_matches(pattern, host)); - matches_include && !matches_exclude -} - -fn host_matches(pattern: &str, host: &str) -> bool { - if pattern == "*" || pattern == "**" { - return true; - } - if !pattern.contains('*') { - return pattern.eq_ignore_ascii_case(host); - } - glob::Pattern::new(&pattern.to_ascii_lowercase()) - .is_ok_and(|pattern| pattern.matches(&host.to_ascii_lowercase())) + .try_fold(false, |matched, pattern| { + openshell_policy::middleware_host_matches(pattern, host) + .map(|matches| matched || matches) + .map_err(|error| miette::miette!(error)) + })?; + Ok(matches_include && !matches_exclude) } fn chain_entry_from_value(value: ®orus::Value) -> Result { @@ -1003,25 +828,6 @@ fn get_field<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Valu } } -fn get_array<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a [regorus::Value]> { - let regorus::Value::Array(values) = get_field(val, key)? else { - return None; - }; - Some(values) -} - -fn get_usize(val: ®orus::Value, key: &str) -> Option { - let value = get_field(val, key)?; - let regorus::Value::Number(number) = value else { - return None; - }; - let value = number.as_f64()?; - if !value.is_finite() || value.fract() != 0.0 || value < 0.0 { - return None; - } - format!("{value:.0}").parse::().ok() -} - fn regorus_value_to_struct(value: ®orus::Value) -> prost_types::Struct { let regorus::Value::Object(map) = value else { return prost_types::Struct::default(); @@ -1383,6 +1189,29 @@ fn validate_middleware_policies(data: &serde_json::Value) -> Vec { "middleware config '{name}' has invalid on_error '{on_error}'" )); } + + let Some(selector) = mw.get("endpoints") else { + errors.push(format!( + "middleware config '{name}' requires an endpoint selector" + )); + continue; + }; + let includes = json_string_array(selector.get("include")); + let excludes = json_string_array(selector.get("exclude")); + if includes.is_empty() { + errors.push(format!( + "middleware config '{name}' endpoint selector must include at least one host pattern" + )); + } + for pattern in includes.iter().chain(&excludes) { + if let Err(error) = + openshell_policy::middleware_host_matches(pattern, "validation.invalid") + { + errors.push(format!( + "middleware config '{name}' has invalid endpoint selector pattern '{pattern}': {error}" + )); + } + } } let Some(policies) = data @@ -1393,127 +1222,25 @@ fn validate_middleware_policies(data: &serde_json::Value) -> Vec { }; for (policy_name, policy) in policies { - let policy_middleware = json_string_array(policy.get("middleware")); - for name in &policy_middleware { - if !names.contains(name) { - errors.push(format!( - "network policy '{policy_name}' references unknown middleware config '{name}'" - )); - } - } for endpoint in policy .get("endpoints") .and_then(serde_json::Value::as_array) .map_or(&[][..], Vec::as_slice) { - let endpoint_middleware = json_string_array(endpoint.get("middleware")); - for name in &endpoint_middleware { - if !names.contains(name) { - errors.push(format!( - "network policy '{policy_name}' endpoint references unknown middleware config '{name}'" - )); - } - } let tls_skip = endpoint .get("tls") .and_then(serde_json::Value::as_str) .is_some_and(|tls| tls == "skip"); - if tls_skip && (!policy_middleware.is_empty() || !endpoint_middleware.is_empty()) { - errors.push(format!( - "network policy '{policy_name}' attaches middleware to a tls: skip endpoint" - )); - } if tls_skip && global_selector_matches_any_middleware(middlewares, endpoint) { errors.push(format!( "network policy '{policy_name}' tls: skip endpoint matches a global middleware selector" )); } } - validate_ambiguous_middleware_endpoints( - policy_name, - policy, - &policy_middleware, - &mut errors, - ); } errors } -fn validate_ambiguous_middleware_endpoints( - policy_name: &str, - policy: &serde_json::Value, - policy_middleware: &[String], - errors: &mut Vec, -) { - let endpoints = policy - .get("endpoints") - .and_then(serde_json::Value::as_array) - .map_or(&[][..], Vec::as_slice); - let mut seen: Vec<(usize, MiddlewareEndpointKey, Vec)> = Vec::new(); - for (index, endpoint) in endpoints.iter().enumerate() { - let key = middleware_endpoint_key(endpoint); - let endpoint_middleware = json_string_array(endpoint.get("middleware")); - let chain = - dedup_middleware_names(policy_middleware.iter().chain(endpoint_middleware.iter())); - for (previous_index, previous_key, previous_chain) in &seen { - if previous_key == &key && previous_chain != &chain { - errors.push(format!( - "network policy '{policy_name}' endpoints[{previous_index}] and endpoints[{index}] have equivalent middleware selection keys ({key}) but different middleware chains" - )); - } - } - seen.push((index, key, chain)); - } -} - -#[derive(Debug, PartialEq, Eq)] -struct MiddlewareEndpointKey { - host: String, - ports: Vec, - path: String, -} - -impl std::fmt::Display for MiddlewareEndpointKey { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "host={} ports={:?} path={}", - if self.host.is_empty() { - "" - } else { - self.host.as_str() - }, - self.ports, - if self.path.is_empty() { - "" - } else { - self.path.as_str() - } - ) - } -} - -fn middleware_endpoint_key(endpoint: &serde_json::Value) -> MiddlewareEndpointKey { - let host = endpoint - .get("host") - .and_then(serde_json::Value::as_str) - .unwrap_or_default() - .to_ascii_lowercase(); - let mut ports: Vec = endpoint - .get("ports") - .and_then(serde_json::Value::as_array) - .map(|ports| ports.iter().filter_map(serde_json::Value::as_u64).collect()) - .unwrap_or_default(); - ports.sort_unstable(); - ports.dedup(); - let path = endpoint - .get("path") - .and_then(serde_json::Value::as_str) - .unwrap_or_default() - .to_string(); - MiddlewareEndpointKey { host, ports, path } -} - fn json_string_array(value: Option<&serde_json::Value>) -> Vec { value .and_then(serde_json::Value::as_array) @@ -1542,8 +1269,12 @@ fn global_selector_matches_any_middleware( let includes = json_string_array(selector.get("include")); let excludes = json_string_array(selector.get("exclude")); !includes.is_empty() - && includes.iter().any(|pattern| host_matches(pattern, host)) - && !excludes.iter().any(|pattern| host_matches(pattern, host)) + && includes.iter().any(|pattern| { + openshell_policy::middleware_host_matches(pattern, host).unwrap_or(false) + }) + && !excludes.iter().any(|pattern| { + openshell_policy::middleware_host_matches(pattern, host).unwrap_or(false) + }) }) } @@ -1908,9 +1639,6 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St allow_all_known_mcp_methods.into(); } } - if !e.middleware.is_empty() { - ep["middleware"] = e.middleware.clone().into(); - } ep }) .collect(); @@ -1936,14 +1664,11 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St entries }) .collect(); - let mut policy = serde_json::json!({ + let policy = serde_json::json!({ "name": rule.name, "endpoints": endpoints, "binaries": binaries, }); - if !rule.middleware.is_empty() { - policy["middleware"] = rule.middleware.clone().into(); - } (key.clone(), policy) }) .collect(); @@ -2058,7 +1783,6 @@ mod tests { path: "/usr/local/bin/claude".to_string(), ..Default::default() }], - ..Default::default() }, ); network_policies.insert( @@ -2074,7 +1798,6 @@ mod tests { path: "/usr/bin/glab".to_string(), ..Default::default() }], - ..Default::default() }, ); ProtoSandboxPolicy { @@ -3417,7 +3140,6 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -3490,7 +3212,6 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], - middleware: vec![], }, ); @@ -3564,7 +3285,6 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], - middleware: vec![], }, ); @@ -4443,7 +4163,6 @@ network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -4502,7 +4221,6 @@ network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -4562,7 +4280,6 @@ network_policies: path: "/usr/local/bin/claude".to_string(), ..Default::default() }], - middleware: vec![], }, ); let proto = ProtoSandboxPolicy { @@ -4624,7 +4341,6 @@ network_policies: path: "/usr/local/bin/aws".to_string(), ..Default::default() }], - middleware: vec![], }, ); let proto = ProtoSandboxPolicy { @@ -4685,7 +4401,6 @@ network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5636,7 +5351,6 @@ process: ..Default::default() }], binaries: vec![proposal_binary], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5692,7 +5406,6 @@ process: path: "/usr/bin/python".to_string(), ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5764,7 +5477,6 @@ process: path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5996,7 +5708,6 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -6701,7 +6412,6 @@ network_policies: path: "/usr/bin/python3".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -7099,7 +6809,7 @@ network_policies: } #[test] - fn middleware_chain_orders_global_policy_endpoint_once() { + fn middleware_chain_uses_matching_selector_declaration_order() { let data = r#" network_middlewares: - name: global-redactor @@ -7108,18 +6818,20 @@ network_middlewares: include: ["api.example.com"] - name: policy-redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] - name: endpoint-redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] network_policies: api: name: api - middleware: ["global-redactor", "policy-redactor"] endpoints: - host: api.example.com port: 443 protocol: rest enforcement: enforce - middleware: ["policy-redactor", "endpoint-redactor"] rules: - allow: { method: POST, path: "/v1/**" } binaries: @@ -7135,7 +6847,7 @@ network_policies: cmdline_paths: vec![], }; let (chain, _) = engine - .query_middleware_chain_with_generation(&input, "/v1/messages") + .query_middleware_chain_with_generation(&input) .unwrap(); let names: Vec<_> = chain.iter().map(|entry| entry.name.as_str()).collect(); assert_eq!( @@ -7144,63 +6856,9 @@ network_policies: ); } - #[test] - fn middleware_validation_rejects_ambiguous_duplicate_endpoint_middleware() { - let data = r#" -network_middlewares: - - name: first-redactor - middleware: openshell/secrets - - name: second-redactor - middleware: openshell/secrets -network_policies: - api: - name: api - endpoints: - - host: api.example.com - port: 443 - protocol: rest - enforcement: enforce - middleware: ["first-redactor"] - access: full - - host: api.example.com - port: 443 - protocol: rest - enforcement: enforce - middleware: ["second-redactor"] - access: full - binaries: - - { path: /usr/bin/curl } -"#; - let err = match OpaEngine::from_strings(TEST_POLICY, data) { - Ok(_) => panic!("equivalent endpoints with different middleware should be invalid"), - Err(err) => err, - }; - assert!( - err.to_string() - .contains("equivalent middleware selection keys"), - "{err:?}" - ); - } - #[test] fn middleware_policy_validation_rejects_bad_configs() { let cases = [ - ( - "missing reference", - r#" -network_middlewares: - - name: redactor - middleware: openshell/secrets -network_policies: - api: - middleware: ["missing"] - endpoints: - - { host: api.example.com, port: 443 } - binaries: - - { path: /usr/bin/curl } -"#, - "unknown middleware config 'missing'", - ), ( "invalid on_error", r#" @@ -7208,6 +6866,8 @@ network_middlewares: - name: redactor middleware: openshell/secrets on_error: maybe + endpoints: + include: ["api.example.com"] "#, "invalid on_error", ), @@ -7217,8 +6877,12 @@ network_middlewares: network_middlewares: - name: redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] - name: redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] "#, "duplicate middleware config 'redactor'", ), @@ -7228,22 +6892,45 @@ network_middlewares: network_middlewares: - name: sigv4 middleware: openshell/sigv4 + endpoints: + include: ["api.example.com"] "#, "unsupported built-in", ), ( - "tls skip attachment", + "missing selector", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets +"#, + "requires an endpoint selector", + ), + ( + "malformed selector", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets + endpoints: + include: ["api[.example.com"] +"#, + "invalid host pattern", + ), + ( + "tls skip selector", r#" network_middlewares: - name: redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] network_policies: api: endpoints: - host: api.example.com port: 443 tls: skip - middleware: ["redactor"] binaries: - { path: /usr/bin/curl } "#, @@ -7263,6 +6950,29 @@ network_policies: } } + #[test] + fn from_proto_revalidates_middleware_policy() { + let mut policy = openshell_policy::restrictive_default_policy(); + policy + .network_middlewares + .push(openshell_core::proto::NetworkMiddlewareConfig { + name: "redactor".into(), + middleware: "openshell/secrets".into(), + endpoints: Some(openshell_core::proto::MiddlewareEndpointSelector { + include: vec!["api[.example.com".into()], + exclude: Vec::new(), + }), + ..Default::default() + }); + + let error = OpaEngine::from_proto(&policy) + .err() + .expect("supervisor must reject invalid effective middleware policy") + .to_string(); + assert!(error.contains("policy validation failed"), "{error}"); + assert!(error.contains("invalid host pattern"), "{error}"); + } + #[test] fn l7_head_denied_when_only_post_allowed() { let engine = OpaEngine::from_strings( diff --git a/crates/openshell-supervisor-network/src/policy_local.rs b/crates/openshell-supervisor-network/src/policy_local.rs index fa8029c72..3cbc31502 100644 --- a/crates/openshell-supervisor-network/src/policy_local.rs +++ b/crates/openshell-supervisor-network/src/policy_local.rs @@ -1047,7 +1047,6 @@ fn network_rule_from_json( name: rule.name.unwrap_or_default(), endpoints, binaries, - middleware: Vec::new(), }) } @@ -1134,7 +1133,6 @@ fn network_endpoint_from_json( credential_signing: String::new(), signing_service: String::new(), signing_region: String::new(), - middleware: Vec::new(), }) } @@ -1831,7 +1829,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }), ..Default::default() }; @@ -1856,7 +1853,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() } } @@ -1920,7 +1916,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() })); }) }; diff --git a/crates/openshell-supervisor-network/src/proxy.rs b/crates/openshell-supervisor-network/src/proxy.rs index f8310fbdc..afec98666 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -4194,7 +4194,7 @@ async fn handle_forward_proxy( cmdline_paths: decision.cmdline_paths.clone(), }; let (chain, generation) = - opa_engine.query_middleware_chain_with_generation(&middleware_input, middleware_path)?; + opa_engine.query_middleware_chain_with_generation(&middleware_input)?; if generation != forward_generation_guard.captured_generation() { emit_l7_tunnel_close_after_policy_change( &host_lc, @@ -4224,10 +4224,11 @@ async fn handle_forward_proxy( &upstream_target, forward_request_bytes, )?; - forward_request_bytes = match crate::l7::relay::apply_middleware_chain( + forward_request_bytes = match crate::l7::relay::apply_middleware_chain_for_scheme( request, client, &l7_ctx, + &scheme, chain, &forward_generation_guard, ) diff --git a/proto/middleware.proto b/proto/middleware.proto index d5d2ad48d..2944227d8 100644 --- a/proto/middleware.proto +++ b/proto/middleware.proto @@ -25,6 +25,8 @@ message MiddlewareBinding { string id = 1; string operation = 2; string phase = 3; + // Maximum request or replacement body this binding can process. + uint64 max_body_bytes = 4; } message ValidateConfigRequest { diff --git a/proto/sandbox.proto b/proto/sandbox.proto index a73d762e5..04cbd6776 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -63,11 +63,9 @@ message NetworkPolicyRule { repeated NetworkEndpoint endpoints = 2; // Allowed binary identities. repeated NetworkBinary binaries = 3; - // Ordered middleware configs applied to every endpoint in this policy. - repeated string middleware = 4; } -// A reusable middleware config referenced by network policies/endpoints. +// A reusable middleware config selected for admitted egress by host. message NetworkMiddlewareConfig { // Policy-local config name. string name = 1; @@ -77,7 +75,7 @@ message NetworkMiddlewareConfig { google.protobuf.Struct config = 3; // Failure behavior: "fail_closed" (default) or "fail_open". string on_error = 4; - // Optional global endpoint selector for this config. + // Host selector controlling which admitted destinations use this config. MiddlewareEndpointSelector endpoints = 5; } @@ -168,8 +166,6 @@ message NetworkEndpoint { uint32 json_rpc_max_body_bytes = 22; // MCP-only policy and inspection options. Only used when protocol is "mcp". McpOptions mcp = 23; - // Ordered middleware configs applied to this endpoint after policy-level middleware. - repeated string middleware = 24; } // MCP options are grouped so MCP-specific policy can grow without adding more From bb08a337952881a2d7a379ad87ef35a9ae90437a Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Mon, 29 Jun 2026 15:28:14 -0700 Subject: [PATCH 07/27] feat(supervisor-middleware): support external services Signed-off-by: Piotr Mlocek --- Cargo.lock | 3 + architecture/gateway.md | 8 + architecture/sandbox.md | 10 + crates/openshell-core/src/grpc_client.rs | 41 +- crates/openshell-sandbox/Cargo.toml | 1 + crates/openshell-sandbox/src/lib.rs | 68 +- crates/openshell-server/Cargo.toml | 1 + crates/openshell-server/src/config_file.rs | 55 ++ crates/openshell-server/src/grpc/policy.rs | 71 ++- crates/openshell-server/src/grpc/sandbox.rs | 1 + crates/openshell-server/src/lib.rs | 25 + crates/openshell-server/src/middleware.rs | 43 ++ .../Cargo.toml | 5 +- .../src/lib.rs | 601 ++++++++++++++++-- .../src/remote.rs | 91 +++ .../src/l7/relay.rs | 266 ++++---- .../openshell-supervisor-network/src/opa.rs | 34 +- .../openshell-supervisor-network/src/proxy.rs | 2 + docs/reference/gateway-config.mdx | 28 + docs/reference/policy-schema.mdx | 32 +- docs/sandboxes/policies.mdx | 36 +- proto/sandbox.proto | 16 + 22 files changed, 1246 insertions(+), 192 deletions(-) create mode 100644 crates/openshell-server/src/middleware.rs create mode 100644 crates/openshell-supervisor-middleware/src/remote.rs diff --git a/Cargo.lock b/Cargo.lock index 0634d12e9..6819aab32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3833,6 +3833,7 @@ dependencies = [ "openshell-core", "openshell-ocsf", "openshell-policy", + "openshell-supervisor-middleware", "openshell-supervisor-network", "openshell-supervisor-process", "rustls", @@ -3887,6 +3888,7 @@ dependencies = [ "openshell-providers", "openshell-router", "openshell-server-macros", + "openshell-supervisor-middleware", "petname", "pin-project-lite", "prost", @@ -3938,6 +3940,7 @@ dependencies = [ "prost-types", "regex", "tokio", + "tokio-stream", "tonic", ] diff --git a/architecture/gateway.md b/architecture/gateway.md index d873b2a10..ba8437b7f 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -271,6 +271,14 @@ config path. A gateway-global policy can override sandbox-scoped policy. The sandbox supervisor polls for config revisions and hot-reloads dynamic policy when the policy engine accepts the update. +External supervisor middleware registration is operator-owned gateway +configuration. At startup the gateway connects to each service, validates its +described bindings and operator body limit, and rejects duplicate binding IDs. +Before persisting a policy, the gateway asks each selected implementation to +validate its config. The effective sandbox config contains only the registered +services required by that policy; supervisors invoke those services directly on +the request path. + Provider credential expiry is enforced during gateway-to-sandbox credential resolution and again by the sandbox placeholder resolver. This keeps expired credentials from resolving even when a running sandbox still has retained diff --git a/architecture/sandbox.md b/architecture/sandbox.md index 580d8f96d..deec00f32 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -66,6 +66,14 @@ matchers; generic JSON-RPC rules match only the method. JSON-RPC responses and server-to-client MCP messages on response or SSE streams are relayed but are not currently parsed for policy enforcement. +For admitted HTTP requests, the proxy can run an ordered supervisor middleware +chain before credential injection. Host selectors choose the chain independently +of the network rule that admitted the request. Built-ins run in-process; +operator-registered external services are called directly from the supervisor +over the common middleware gRPC contract. The gateway validates external +service capabilities and policy-owned config before delivery. Supervisors keep +the last-known-good service registry when a live config reload fails. + `https://inference.local` is special. It bypasses OPA network policy and is handled by the inference interception path: @@ -176,6 +184,8 @@ quickly. - If gateway config polling fails, the sandbox keeps its last-known-good policy. - If a live policy update is invalid, the supervisor rejects it and keeps the current policy. +- If an external middleware call fails, the selected config's `on_error` + behavior decides whether to deny the request or continue without that stage. - Existing raw byte streams are connection scoped. Dynamic policy changes apply to new connections or the next parsed HTTP request where the proxy can safely re-evaluate. diff --git a/crates/openshell-core/src/grpc_client.rs b/crates/openshell-core/src/grpc_client.rs index 96158a1d1..57f72dca6 100644 --- a/crates/openshell-core/src/grpc_client.rs +++ b/crates/openshell-core/src/grpc_client.rs @@ -24,11 +24,11 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use crate::proto::{ DenialSummary, GetDraftPolicyRequest, GetInferenceBundleRequest, GetInferenceBundleResponse, - GetSandboxConfigRequest, GetSandboxProviderEnvironmentRequest, IssueSandboxTokenRequest, - NetworkActivitySummary, PolicyChunk, PolicySource, PolicyStatus, RefreshSandboxTokenRequest, - ReportPolicyStatusRequest, SandboxPolicy as ProtoSandboxPolicy, SubmitPolicyAnalysisRequest, - SubmitPolicyAnalysisResponse, UpdateConfigRequest, inference_client::InferenceClient, - open_shell_client::OpenShellClient, + GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + IssueSandboxTokenRequest, NetworkActivitySummary, PolicyChunk, PolicySource, PolicyStatus, + RefreshSandboxTokenRequest, ReportPolicyStatusRequest, SandboxPolicy as ProtoSandboxPolicy, + SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, UpdateConfigRequest, + inference_client::InferenceClient, open_shell_client::OpenShellClient, }; use crate::sandbox_env; use miette::{IntoDiagnostic, Result, WrapErr}; @@ -573,19 +573,36 @@ pub async fn fetch_policy(endpoint: &str, sandbox_id: &str) -> Result Result { + debug!(endpoint = %endpoint, sandbox_id = %sandbox_id, "Connecting to OpenShell server"); + let mut client = connect(endpoint).await?; + fetch_sandbox_config_with_client(&mut client, sandbox_id).await +} + +async fn fetch_sandbox_config_with_client( client: &mut OpenShellClient, sandbox_id: &str, -) -> Result> { - let response = client +) -> Result { + client .get_sandbox_config(GetSandboxConfigRequest { sandbox_id: sandbox_id.to_string(), }) .await - .into_diagnostic()?; + .map(tonic::Response::into_inner) + .into_diagnostic() +} - let inner = response.into_inner(); +/// Fetch sandbox policy using an existing client connection. +async fn fetch_policy_with_client( + client: &mut OpenShellClient, + sandbox_id: &str, +) -> Result> { + let inner = fetch_sandbox_config_with_client(client, sandbox_id).await?; // version 0 with no policy means the sandbox was created without one. if inner.version == 0 && inner.policy.is_none() { @@ -711,6 +728,7 @@ pub struct SettingsPollResult { /// When `policy_source` is `Global`, the version of the global policy revision. pub global_policy_version: u32, pub provider_env_revision: u64, + pub external_middleware: Vec, } pub struct ProviderEnvironmentResult { @@ -755,6 +773,7 @@ impl CachedOpenShellClient { settings: inner.settings, global_policy_version: inner.global_policy_version, provider_env_revision: inner.provider_env_revision, + external_middleware: inner.external_middleware, }) } diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 086dbe02c..d3c3e7108 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -19,6 +19,7 @@ openshell-core = { path = "../openshell-core", default-features = false } openshell-ocsf = { path = "../openshell-ocsf" } openshell-policy = { path = "../openshell-policy" } openshell-supervisor-network = { path = "../openshell-supervisor-network" } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } openshell-supervisor-process = { path = "../openshell-supervisor-process" } # Async runtime diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 53b1eba58..30da1c292 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -1409,12 +1409,12 @@ async fn load_policy( endpoint = %endpoint, "Fetching sandbox policy via gRPC" ); - let proto_policy = grpc_retry("Policy fetch", || { - openshell_core::grpc_client::fetch_policy(endpoint, id) + let mut sandbox_config = grpc_retry("Policy fetch", || { + openshell_core::grpc_client::fetch_sandbox_config(endpoint, id) }) .await?; - let mut proto_policy = if let Some(p) = proto_policy { + let mut proto_policy = if let Some(p) = sandbox_config.policy.take() { p } else { // No policy configured on the server. Discover from disk or @@ -1442,7 +1442,7 @@ async fn load_policy( // Sync and re-fetch over a single connection to avoid extra // TLS handshakes. - grpc_retry("Policy discovery sync", || { + let synced = grpc_retry("Policy discovery sync", || { openshell_core::grpc_client::discover_and_sync_policy( endpoint, id, @@ -1450,7 +1450,12 @@ async fn load_policy( &discovered, ) }) - .await? + .await?; + sandbox_config = grpc_retry("Policy refetch after discovery", || { + openshell_core::grpc_client::fetch_sandbox_config(endpoint, id) + }) + .await?; + sandbox_config.policy.take().unwrap_or(synced) }; // Ensure baseline filesystem paths are present for proxy-mode @@ -1476,7 +1481,14 @@ async fn load_policy( // container hasn't started yet. After the entrypoint spawns, the // engine is rebuilt with the real PID for symlink resolution. info!("Creating OPA engine from proto policy data"); - let opa_engine = Some(Arc::new(OpaEngine::from_proto(&proto_policy)?)); + let engine = OpaEngine::from_proto(&proto_policy)?; + let middleware_registry = + openshell_supervisor_middleware::MiddlewareRegistry::connect_external( + sandbox_config.external_middleware, + ) + .await?; + engine.replace_middleware_registry(middleware_registry)?; + let opa_engine = Some(Arc::new(engine)); let policy = SandboxPolicy::try_from(proto_policy.clone())?; return Ok((policy, opa_engine, Some(proto_policy))); @@ -1626,6 +1638,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { let mut current_config_revision: u64 = 0; let mut current_provider_env_revision: u64 = ctx.provider_credentials.snapshot().revision; let mut current_policy_hash = String::new(); + let mut current_external_middleware = Vec::new(); let mut current_settings: std::collections::HashMap< String, openshell_core::proto::EffectiveSetting, @@ -1637,6 +1650,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { apply_ocsf_json_setting(&ctx.ocsf_enabled, &result.settings); current_config_revision = result.config_revision; current_policy_hash = result.policy_hash.clone(); + current_external_middleware = result.external_middleware; current_settings = result.settings; debug!( config_revision = current_config_revision, @@ -1666,6 +1680,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } let policy_changed = result.policy_hash != current_policy_hash; + let middleware_changed = result.external_middleware != current_external_middleware; // Log which settings changed. log_setting_changes(¤t_settings, &result.settings); @@ -1724,6 +1739,47 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } } + if middleware_changed { + match openshell_supervisor_middleware::MiddlewareRegistry::connect_external( + result.external_middleware.clone(), + ) + .await + .and_then(|registry| ctx.opa_engine.replace_middleware_registry(registry)) + { + Ok(()) => { + current_external_middleware = result.external_middleware.clone(); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped( + "external_middleware_count", + serde_json::json!(current_external_middleware.len()) + ) + .message(format!( + "External middleware registry reloaded [service_count:{}]", + current_external_middleware.len() + )) + .build() + ); + } + Err(error) => { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "failed") + .message(format!( + "External middleware registry reload failed, keeping last-known-good registry [error:{error}]" + )) + .build() + ); + continue; + } + } + } + // Only reload OPA when the policy payload actually changed. if policy_changed { let Some(policy) = result.policy.as_ref() else { diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index b5c9b34d7..fafc72ba7 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -26,6 +26,7 @@ openshell-prover = { path = "../openshell-prover" } openshell-providers = { path = "../openshell-providers" } openshell-router = { path = "../openshell-router" } openshell-server-macros = { path = "../openshell-server-macros" } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } # Kubernetes client (used by the `generate-certs` subcommand) kube = { workspace = true } diff --git a/crates/openshell-server/src/config_file.rs b/crates/openshell-server/src/config_file.rs index b65b5f3b0..13c7e9ebb 100644 --- a/crates/openshell-server/src/config_file.rs +++ b/crates/openshell-server/src/config_file.rs @@ -25,6 +25,7 @@ use std::net::SocketAddr; use std::path::{Path, PathBuf}; use openshell_core::config::ComputeDriverKind; +use openshell_core::proto::ExternalMiddlewareService; use openshell_core::{GatewayAuthConfig, GatewayJwtConfig, MtlsAuthConfig, OidcConfig, TlsConfig}; use serde::{Deserialize, Serialize}; @@ -151,6 +152,12 @@ pub struct GatewayFileSection { #[serde(default)] pub gateway_jwt: Option, + // ── Supervisor middleware ───────────────────────────────────────────── + /// Statically registered external middleware services. Registration is + /// operator-owned and changes require a gateway restart. + #[serde(default)] + pub middleware: Vec, + // ── Disallowed-in-file fields ──────────────────────────────────────── // // Captured so we can produce a friendly "set this via env/CLI instead" @@ -160,6 +167,32 @@ pub struct GatewayFileSection { pub database_url: Option, } +/// One `[[openshell.gateway.middleware]]` external middleware registration. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct MiddlewareServiceFileConfig { + /// Operator-facing name used for diagnostics. + pub name: String, + /// Plaintext gRPC endpoint reachable by the gateway and supervisors. + pub endpoint: String, + /// Required explicit opt-in to the local-development-only insecure mode. + #[serde(default)] + pub allow_insecure: bool, + /// Operator-owned body limit for every binding exposed by this service. + pub max_body_bytes: u64, +} + +impl From<&MiddlewareServiceFileConfig> for ExternalMiddlewareService { + fn from(config: &MiddlewareServiceFileConfig) -> Self { + Self { + name: config.name.clone(), + endpoint: config.endpoint.clone(), + allow_insecure: config.allow_insecure, + max_body_bytes: config.max_body_bytes, + } + } +} + #[derive(Debug, thiserror::Error)] pub enum ConfigFileError { #[error("failed to read gateway config file '{}': {source}", path.display())] @@ -401,6 +434,28 @@ allow_unauthenticated_users = true assert!(auth.allow_unauthenticated_users); } + #[test] + fn parses_external_middleware_registration() { + let toml = r#" +[[openshell.gateway.middleware]] +name = "local-guard" +endpoint = "http://127.0.0.1:50051" +allow_insecure = true +max_body_bytes = 262144 +"#; + let tmp = write_tmp(toml); + let file = load(tmp.path()).expect("valid middleware registration parses"); + assert_eq!( + file.openshell.gateway.middleware, + vec![MiddlewareServiceFileConfig { + name: "local-guard".into(), + endpoint: "http://127.0.0.1:50051".into(), + allow_insecure: true, + max_body_bytes: 262_144, + }] + ); + } + #[test] fn rejects_database_url_in_file() { let toml = r#" diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index d3bc213ba..c29de71ff 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -1218,8 +1218,27 @@ pub(super) async fn handle_get_sandbox_config( } } + if let Some(policy) = policy.as_ref() { + state + .middleware_registry + .ensure_policy_bindings_registered(policy) + .map_err(|error| { + Status::failed_precondition(format!( + "effective policy middleware registration is invalid: {error}" + )) + })?; + } + let settings = merge_effective_settings(&global_settings, &sandbox_settings)?; - let config_revision = compute_config_revision(policy.as_ref(), &settings, policy_source); + let external_middleware = state + .middleware_registry + .required_external_services(policy.as_ref()); + let config_revision = compute_config_revision( + policy.as_ref(), + &settings, + policy_source, + &external_middleware, + ); let provider_env_revision = compute_provider_env_revision(state.store.as_ref(), &sandbox_provider_names).await?; @@ -1232,6 +1251,7 @@ pub(super) async fn handle_get_sandbox_config( policy_source: policy_source.into(), global_policy_version, provider_env_revision, + external_middleware, })) } @@ -1510,6 +1530,8 @@ async fn handle_update_config_inner( openshell_policy::ensure_sandbox_process_identity(&mut new_policy); validate_no_reserved_provider_policy_keys(&new_policy)?; validate_policy_safety(&new_policy)?; + crate::middleware::validate_policy(state.middleware_registry.as_ref(), &new_policy) + .await?; let payload = new_policy.encode_to_vec(); let hash = deterministic_policy_hash(&new_policy); @@ -1827,9 +1849,11 @@ async fn handle_update_config_inner( validate_no_reserved_provider_policy_keys(&new_policy)?; } + validate_policy_safety(&new_policy)?; + crate::middleware::validate_policy(state.middleware_registry.as_ref(), &new_policy).await?; + if let Some(baseline_policy) = spec.policy.as_ref() { validate_static_fields_unchanged(baseline_policy, &new_policy)?; - validate_policy_safety(&new_policy)?; } else { // Backfill spec.policy using CAS (first-time policy discovery) let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; @@ -3120,6 +3144,7 @@ fn compute_config_revision( policy: Option<&ProtoSandboxPolicy>, settings: &HashMap, policy_source: PolicySource, + external_middleware: &[openshell_core::proto::ExternalMiddlewareService], ) -> u64 { let mut hasher = Sha256::new(); hasher.update((policy_source as i32).to_le_bytes()); @@ -3152,6 +3177,11 @@ fn compute_config_revision( } } } + let mut middleware = external_middleware.iter().collect::>(); + middleware.sort_by(|left, right| left.name.cmp(&right.name)); + for service in middleware { + hasher.update(service.encode_to_vec()); + } let digest = hasher.finalize(); let mut bytes = [0_u8; 8]; @@ -8989,7 +9019,7 @@ mod tests { }, ); - let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); settings.insert( "mode".to_string(), EffectiveSetting { @@ -8999,7 +9029,7 @@ mod tests { scope: SettingScope::Sandbox.into(), }, ); - let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); + let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); assert_ne!(rev_a, rev_b); } @@ -9264,8 +9294,8 @@ mod tests { }, ); - let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); - let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); + let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); assert_eq!(rev_a, rev_b); } @@ -9281,8 +9311,8 @@ mod tests { }; let settings = HashMap::new(); - let rev_a = compute_config_revision(Some(&policy_a), &settings, PolicySource::Sandbox); - let rev_b = compute_config_revision(Some(&policy_b), &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(Some(&policy_a), &settings, PolicySource::Sandbox, &[]); + let rev_b = compute_config_revision(Some(&policy_b), &settings, PolicySource::Sandbox, &[]); assert_ne!(rev_a, rev_b); } @@ -9291,11 +9321,28 @@ mod tests { let policy = ProtoSandboxPolicy::default(); let settings = HashMap::new(); - let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); - let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Global); + let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); + let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Global, &[]); assert_ne!(rev_a, rev_b); } + #[test] + fn config_revision_changes_when_external_middleware_changes() { + let policy = ProtoSandboxPolicy::default(); + let settings = HashMap::new(); + let service = openshell_core::proto::ExternalMiddlewareService { + name: "local-guard".into(), + endpoint: "http://127.0.0.1:50051".into(), + allow_insecure: true, + max_body_bytes: 1024, + }; + + let without = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); + let with = + compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[service]); + assert_ne!(without, with); + } + #[test] fn config_revision_without_policy_still_hashes_settings() { let mut settings = HashMap::new(); @@ -9309,7 +9356,7 @@ mod tests { }, ); - let rev_a = compute_config_revision(None, &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(None, &settings, PolicySource::Sandbox, &[]); settings.insert( "log_level".to_string(), @@ -9321,7 +9368,7 @@ mod tests { }, ); - let rev_b = compute_config_revision(None, &settings, PolicySource::Sandbox); + let rev_b = compute_config_revision(None, &settings, PolicySource::Sandbox, &[]); assert_ne!(rev_a, rev_b); } diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 04d5a4ed5..203cd7dbe 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -164,6 +164,7 @@ async fn handle_create_sandbox_inner( openshell_policy::ensure_sandbox_process_identity(policy); validate_no_reserved_provider_policy_keys(policy)?; validate_policy_safety(policy)?; + crate::middleware::validate_policy(state.middleware_registry.as_ref(), policy).await?; } let id = uuid::Uuid::new_v4().to_string(); diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index 6462ccbbf..bca8abe2e 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -32,6 +32,7 @@ mod defaults; mod grpc; mod http; mod inference; +mod middleware; mod multiplex; mod persistence; pub(crate) mod policy_store; @@ -53,6 +54,8 @@ mod ws_tunnel; use metrics_exporter_prometheus::PrometheusBuilder; use openshell_core::{ComputeDriverKind, Config, Error, Result}; +use openshell_supervisor_middleware::MiddlewareRegistry; +use serde::Deserialize; use std::collections::HashMap; use std::io::ErrorKind; use std::net::SocketAddr; @@ -126,6 +129,9 @@ pub struct ServerState { /// query session state to surface supervisor readiness. pub supervisor_sessions: Arc, + /// Validated built-in and operator-registered supervisor middleware. + pub middleware_registry: Arc, + /// OIDC JWKS cache for JWT validation. `None` when OIDC is not configured. pub oidc_cache: Option>, @@ -192,6 +198,7 @@ impl ServerState { ssh_connections_by_sandbox: Mutex::new(HashMap::new()), settings_mutex: tokio::sync::Mutex::new(()), supervisor_sessions, + middleware_registry: Arc::new(MiddlewareRegistry::default()), oidc_cache, sandbox_jwt_issuer: None, sandbox_jwt_authenticator: None, @@ -223,6 +230,23 @@ pub(crate) async fn run_server( return Err(Error::config("database_url is required")); } + let middleware_registrations = config_file + .as_ref() + .map(|file| { + file.openshell + .gateway + .middleware + .iter() + .map(Into::into) + .collect() + }) + .unwrap_or_default(); + let middleware_registry = Arc::new( + MiddlewareRegistry::connect_external(middleware_registrations) + .await + .map_err(|error| Error::config(format!("middleware registration failed: {error}")))?, + ); + let store = Arc::new(Store::connect(database_url).await?); let oidc_cache = if let Some(ref oidc) = config.oidc { @@ -273,6 +297,7 @@ pub(crate) async fn run_server( supervisor_sessions, oidc_cache, ); + state.middleware_registry = middleware_registry; // Load the gateway-minted sandbox JWT signing key when configured. // Optional so single-driver dev deployments without certgen continue diff --git a/crates/openshell-server/src/middleware.rs b/crates/openshell-server/src/middleware.rs new file mode 100644 index 000000000..4c94f021a --- /dev/null +++ b/crates/openshell-server/src/middleware.rs @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use openshell_core::proto::SandboxPolicy; +use openshell_supervisor_middleware::MiddlewareRegistry; +use tonic::Status; + +/// Validate implementation-owned middleware config before accepting a policy. +pub async fn validate_policy( + registry: &MiddlewareRegistry, + policy: &SandboxPolicy, +) -> Result<(), Status> { + registry + .validate_policy_configs(policy) + .await + .map_err(|error| { + Status::invalid_argument(format!("policy middleware validation failed: {error}")) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use openshell_core::proto::NetworkMiddlewareConfig; + + #[tokio::test] + async fn unregistered_external_binding_is_rejected_before_admission() { + let policy = SandboxPolicy { + network_middlewares: vec![NetworkMiddlewareConfig { + name: "guard".into(), + middleware: "example/content-guard".into(), + ..Default::default() + }], + ..Default::default() + }; + + let error = validate_policy(&MiddlewareRegistry::default(), &policy) + .await + .expect_err("unregistered binding must fail"); + assert_eq!(error.code(), tonic::Code::InvalidArgument); + assert!(error.message().contains("not registered")); + } +} diff --git a/crates/openshell-supervisor-middleware/Cargo.toml b/crates/openshell-supervisor-middleware/Cargo.toml index 4ae355894..e5e53618d 100644 --- a/crates/openshell-supervisor-middleware/Cargo.toml +++ b/crates/openshell-supervisor-middleware/Cargo.toml @@ -17,7 +17,10 @@ miette = { workspace = true } prost-types = { workspace = true } regex = { workspace = true } tokio = { workspace = true } -tonic = { workspace = true } +tonic = { workspace = true, features = ["channel", "server"] } + +[dev-dependencies] +tokio-stream = { workspace = true, features = ["net"] } [lints] workspace = true diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index a9cb52434..828179d18 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -1,9 +1,10 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! In-process supervisor middleware chain execution. +//! Supervisor middleware registration and chain execution. mod builtins; +mod remote; mod service; use std::collections::{BTreeMap, HashMap, HashSet}; @@ -14,14 +15,17 @@ pub use service::InProcessMiddlewareService; use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; use openshell_core::proto::{ - Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, MiddlewareBinding, - MiddlewareManifest, NetworkMiddlewareConfig, RequestContext, + Decision, ExternalMiddlewareService, Finding, HttpRequestEvaluation, HttpRequestTarget, + MiddlewareBinding, MiddlewareManifest, NetworkMiddlewareConfig, RequestContext, SandboxPolicy, + ValidateConfigRequest, }; use tokio::sync::OnceCell; use tonic::Request; pub const API_VERSION: &str = "openshell.middleware.v1"; pub const BUILTIN_SECRETS: &str = builtins::secrets::BINDING_ID; +const HTTP_REQUEST_OPERATION: &str = "HttpRequest"; +const PRE_CREDENTIALS_PHASE: &str = "pre_credentials"; /// Validate the configuration for an in-process middleware implementation. /// @@ -82,9 +86,10 @@ impl TryFrom<&NetworkMiddlewareConfig> for ChainEntry { /// A policy-selected middleware config joined with metadata reported by its /// service's `Describe` call. A missing binding is retained so `on_error` can /// decide whether the request fails open or closed. -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct DescribedChainEntry { entry: ChainEntry, + service: Option>, binding: Option, max_body_bytes: usize, } @@ -182,82 +187,380 @@ fn apply_on_error( #[derive(Clone)] pub struct ChainRunner { - state: Arc, + registry: Arc, } struct MiddlewareServiceState { service: Arc, manifest: OnceCell, + operator_max_body_bytes: Option, } static IN_PROCESS_SERVICE: LazyLock> = LazyLock::new(|| { Arc::new(MiddlewareServiceState { service: Arc::new(InProcessMiddlewareService), manifest: OnceCell::new(), + operator_max_body_bytes: None, }) }); -impl Default for ChainRunner { +/// Validated middleware services available to a gateway or one supervisor. +/// +/// The registry always contains the in-process built-ins. External services +/// are connected and described before construction succeeds, so callers never +/// observe a partially registered service set. +#[derive(Clone)] +pub struct MiddlewareRegistry { + services: Arc>>, + external: Arc>, +} + +impl std::fmt::Debug for MiddlewareRegistry { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter + .debug_struct("MiddlewareRegistry") + .field("service_count", &self.services.len()) + .field("external_count", &self.external.len()) + .finish() + } +} + +#[derive(Clone)] +struct RegisteredExternalService { + registration: ExternalMiddlewareService, + binding_ids: Vec, +} + +impl Default for MiddlewareRegistry { fn default() -> Self { Self { - state: Arc::clone(&IN_PROCESS_SERVICE), + services: Arc::new(vec![Arc::clone(&IN_PROCESS_SERVICE)]), + external: Arc::new(Vec::new()), } } } +fn validate_registration(registration: &ExternalMiddlewareService) -> Result<()> { + if registration.name.trim().is_empty() { + return Err(miette!( + "external middleware registration name cannot be empty" + )); + } + if !registration.allow_insecure { + return Err(miette!( + "middleware registration '{}' must set allow_insecure = true; TLS is not supported in V1", + registration.name + )); + } + if !registration.endpoint.starts_with("http://") { + return Err(miette!( + "middleware registration '{}' endpoint must use http:// in the local-development-only V1", + registration.name + )); + } + if registration.max_body_bytes == 0 { + return Err(miette!( + "middleware registration '{}' max_body_bytes must be greater than zero", + registration.name + )); + } + Ok(()) +} + +fn validate_external_manifest( + registration: &ExternalMiddlewareService, + manifest: &MiddlewareManifest, + operator_max_body_bytes: usize, + known_binding_ids: &mut HashSet, +) -> Result> { + if manifest.api_version != API_VERSION { + return Err(miette!( + "middleware registration '{}' reports unsupported API version '{}'", + registration.name, + manifest.api_version + )); + } + if manifest.bindings.is_empty() { + return Err(miette!( + "middleware registration '{}' describes no bindings", + registration.name + )); + } + + let mut described_ids = Vec::with_capacity(manifest.bindings.len()); + for binding in &manifest.bindings { + if binding.id.trim().is_empty() { + return Err(miette!( + "middleware registration '{}' describes an empty binding id", + registration.name + )); + } + if binding.id.starts_with("openshell/") { + return Err(miette!( + "external middleware registration '{}' cannot claim reserved binding '{}'", + registration.name, + binding.id + )); + } + if binding.operation != HTTP_REQUEST_OPERATION || binding.phase != PRE_CREDENTIALS_PHASE { + return Err(miette!( + "middleware binding '{}' must support {HTTP_REQUEST_OPERATION}/{PRE_CREDENTIALS_PHASE}", + binding.id + )); + } + let advertised = usize::try_from(binding.max_body_bytes).map_err(|_| { + miette!( + "middleware binding '{}' reports a body limit too large for this platform", + binding.id + ) + })?; + if advertised == 0 { + return Err(miette!( + "middleware binding '{}' must advertise a non-zero body limit", + binding.id + )); + } + if operator_max_body_bytes > advertised { + return Err(miette!( + "middleware registration '{}' max_body_bytes ({operator_max_body_bytes}) exceeds binding '{}' capability ({advertised})", + registration.name, + binding.id + )); + } + if !known_binding_ids.insert(binding.id.clone()) { + return Err(miette!( + "middleware binding '{}' is described by more than one service", + binding.id + )); + } + described_ids.push(binding.id.clone()); + } + Ok(described_ids) +} + +impl MiddlewareRegistry { + /// Connect and validate every external service registration. + pub async fn connect_external(registrations: Vec) -> Result { + let mut services = vec![Arc::clone(&IN_PROCESS_SERVICE)]; + let mut external = Vec::with_capacity(registrations.len()); + let mut registration_names = HashSet::new(); + let mut binding_ids = HashSet::from([BUILTIN_SECRETS.to_string()]); + + for registration in registrations { + validate_registration(®istration)?; + if !registration_names.insert(registration.name.clone()) { + return Err(miette!( + "duplicate external middleware registration name '{}'", + registration.name + )); + } + + let operator_max_body_bytes = + usize::try_from(registration.max_body_bytes).map_err(|_| { + miette!( + "middleware registration '{}' body limit is too large for this platform", + registration.name + ) + })?; + let service = Arc::new( + remote::RemoteMiddlewareService::connect( + ®istration.name, + ®istration.endpoint, + ) + .await?, + ); + let manifest = service + .describe(Request::new(())) + .await + .map(tonic::Response::into_inner) + .map_err(|error| { + miette!( + "middleware registration '{}' Describe failed: {}", + registration.name, + safe_reason(&error.to_string()) + ) + })?; + let described_ids = validate_external_manifest( + ®istration, + &manifest, + operator_max_body_bytes, + &mut binding_ids, + )?; + let manifest_cell = OnceCell::new(); + manifest_cell + .set(manifest) + .map_err(|_| miette!("middleware manifest cache initialized twice"))?; + services.push(Arc::new(MiddlewareServiceState { + service, + manifest: manifest_cell, + operator_max_body_bytes: Some(operator_max_body_bytes), + })); + external.push(RegisteredExternalService { + registration, + binding_ids: described_ids, + }); + } + + Ok(Self { + services: Arc::new(services), + external: Arc::new(external), + }) + } + + /// Validate implementation-owned configuration for every middleware entry. + pub async fn validate_policy_configs(&self, policy: &SandboxPolicy) -> Result<()> { + let runner = ChainRunner::from_registry(self.clone()); + for config in &policy.network_middlewares { + runner + .validate_config( + &config.middleware, + config.config.clone().unwrap_or_default(), + ) + .await + .map_err(|error| { + miette!( + "middleware config '{}' is invalid: {}", + config.name, + safe_reason(&error.to_string()) + ) + })?; + } + Ok(()) + } + + /// Check that every policy binding still belongs to the current static + /// registry without making a network call. + pub fn ensure_policy_bindings_registered(&self, policy: &SandboxPolicy) -> Result<()> { + for config in &policy.network_middlewares { + let registered = config.middleware == BUILTIN_SECRETS + || self.external.iter().any(|service| { + service + .binding_ids + .iter() + .any(|binding| binding == &config.middleware) + }); + if !registered { + return Err(miette!( + "middleware binding '{}' used by config '{}' is not registered", + config.middleware, + config.name + )); + } + } + Ok(()) + } + + /// Return only external services referenced by the effective policy. + pub fn required_external_services( + &self, + policy: Option<&SandboxPolicy>, + ) -> Vec { + let Some(policy) = policy else { + return Vec::new(); + }; + let selected: HashSet<&str> = policy + .network_middlewares + .iter() + .map(|config| config.middleware.as_str()) + .collect(); + self.external + .iter() + .filter(|service| { + service + .binding_ids + .iter() + .any(|binding| selected.contains(binding.as_str())) + }) + .map(|service| service.registration.clone()) + .collect() + } +} + +impl Default for ChainRunner { + fn default() -> Self { + Self::from_registry(MiddlewareRegistry::default()) + } +} + impl ChainRunner { pub fn new(service: Arc) -> Self { Self { - state: Arc::new(MiddlewareServiceState { - service, - manifest: OnceCell::new(), + registry: Arc::new(MiddlewareRegistry { + services: Arc::new(vec![Arc::new(MiddlewareServiceState { + service, + manifest: OnceCell::new(), + operator_max_body_bytes: None, + })]), + external: Arc::new(Vec::new()), }), } } - async fn manifest(&self) -> Result<&MiddlewareManifest> { - self.state - .manifest - .get_or_try_init(|| async { - self.state - .service - .describe(Request::new(())) - .await - .map(tonic::Response::into_inner) - .map_err(|error| { - miette!( - "middleware Describe failed: {}", - safe_reason(&error.to_string()) - ) - }) - }) - .await + pub fn from_registry(registry: MiddlewareRegistry) -> Self { + Self { + registry: Arc::new(registry), + } + } + + async fn manifests(&self) -> Result, MiddlewareManifest)>> { + let mut manifests = Vec::with_capacity(self.registry.services.len()); + for state in self.registry.services.iter() { + let manifest = state + .manifest + .get_or_try_init(|| async { + state + .service + .describe(Request::new(())) + .await + .map(tonic::Response::into_inner) + .map_err(|error| { + miette!( + "middleware Describe failed: {}", + safe_reason(&error.to_string()) + ) + }) + }) + .await?; + manifests.push((Arc::clone(state), manifest.clone())); + } + Ok(manifests) } pub async fn describe_chain(&self, entries: &[ChainEntry]) -> Result> { - let manifest = self.manifest().await?; + let manifests = self.manifests().await?; entries .iter() .map(|entry| { - let binding = manifest - .bindings - .iter() - .find(|binding| binding.id == entry.implementation) - .cloned(); + let described = manifests.iter().find_map(|(state, manifest)| { + manifest + .bindings + .iter() + .find(|binding| binding.id == entry.implementation) + .cloned() + .map(|binding| (Arc::clone(state), binding)) + }); + let (service, binding) = described.map_or((None, None), |(service, binding)| { + (Some(service), Some(binding)) + }); let max_body_bytes = binding .as_ref() .map(|binding| { - usize::try_from(binding.max_body_bytes).map_err(|_| { + let advertised = usize::try_from(binding.max_body_bytes).map_err(|_| { miette!( "middleware binding '{}' reports a body limit too large for this platform", binding.id ) - }) + })?; + Ok::<_, miette::Report>(service + .as_ref() + .and_then(|state| state.operator_max_body_bytes) + .unwrap_or(advertised)) }) .transpose()? .unwrap_or(0); Ok(DescribedChainEntry { entry: entry.clone(), + service, binding, max_body_bytes, }) @@ -265,6 +568,44 @@ impl ChainRunner { .collect() } + pub async fn validate_config( + &self, + implementation: &str, + config: prost_types::Struct, + ) -> Result<()> { + let manifests = self.manifests().await?; + let Some((state, _)) = manifests.iter().find(|(_, manifest)| { + manifest + .bindings + .iter() + .any(|binding| binding.id == implementation) + }) else { + return Err(miette!( + "middleware binding '{implementation}' is not registered" + )); + }; + let response = state + .service + .validate_config(Request::new(ValidateConfigRequest { + api_version: API_VERSION.into(), + binding_id: implementation.into(), + config: Some(config), + })) + .await + .map(tonic::Response::into_inner) + .map_err(|error| { + miette!( + "middleware ValidateConfig failed: {}", + safe_reason(&error.to_string()) + ) + })?; + if response.valid { + Ok(()) + } else { + Err(miette!("{}", safe_reason(&response.reason))) + } + } + pub async fn evaluate( &self, entries: &[ChainEntry], @@ -320,8 +661,10 @@ impl ChainRunner { } } let evaluation = build_evaluation(entry, binding, &input, &headers, &body); - let result = match self - .state + let Some(service) = entry.service.as_ref() else { + unreachable!("described binding always has a service") + }; + let result = match service .service .evaluate_http_request(Request::new(evaluation)) .await @@ -545,7 +888,10 @@ pub(crate) fn safe_reason(reason: &str) -> String { #[cfg(test)] mod tests { use super::*; - use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; + use openshell_core::proto::middleware::v1::supervisor_middleware_server::{ + SupervisorMiddleware, SupervisorMiddlewareServer, + }; + use tokio_stream::wrappers::TcpListenerStream; fn entry(name: &str, on_error: OnError) -> ChainEntry { ChainEntry { @@ -736,7 +1082,7 @@ mod tests { async fn validate_config( &self, - _request: Request, + _request: Request, ) -> std::result::Result< tonic::Response, tonic::Status, @@ -780,6 +1126,51 @@ mod tests { } } + fn external_registration(max_body_bytes: u64) -> ExternalMiddlewareService { + ExternalMiddlewareService { + name: "local-guard-service".into(), + endpoint: "http://127.0.0.1:50051".into(), + allow_insecure: true, + max_body_bytes, + } + } + + async fn registry_with_external( + service: Arc, + registration: ExternalMiddlewareService, + ) -> MiddlewareRegistry { + let manifest = service + .describe(Request::new(())) + .await + .expect("describe test service") + .into_inner(); + let operator_max_body_bytes = usize::try_from(registration.max_body_bytes).unwrap(); + let mut known = HashSet::from([BUILTIN_SECRETS.to_string()]); + let binding_ids = validate_external_manifest( + ®istration, + &manifest, + operator_max_body_bytes, + &mut known, + ) + .expect("valid external manifest"); + let manifest_cell = OnceCell::new(); + manifest_cell.set(manifest).expect("manifest cache"); + MiddlewareRegistry { + services: Arc::new(vec![ + Arc::clone(&IN_PROCESS_SERVICE), + Arc::new(MiddlewareServiceState { + service, + manifest: manifest_cell, + operator_max_body_bytes: Some(operator_max_body_bytes), + }), + ]), + external: Arc::new(vec![RegisteredExternalService { + registration, + binding_ids, + }]), + } + } + #[tokio::test] async fn descriptors_are_resolved_from_any_middleware_service() { let runner = ChainRunner::new(Arc::new(ScriptedService { @@ -815,6 +1206,138 @@ mod tests { assert!(outcome.allowed); } + #[tokio::test] + async fn mixed_builtin_and_external_chain_uses_operator_limit() { + let external = Arc::new(ScriptedService { + binding_id: "example/content-guard".into(), + max_body_bytes: 4096, + result: allow_result(), + }); + let registry = registry_with_external(external, external_registration(1024)).await; + let runner = ChainRunner::from_registry(registry); + let external_entry = ChainEntry { + name: "external".into(), + implementation: "example/content-guard".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }; + let entries = [entry("builtin", OnError::FailClosed), external_entry]; + + let described = runner + .describe_chain(&entries) + .await + .expect("describe chain"); + assert_eq!(described[0].max_body_bytes(), 256 * 1024); + assert_eq!(described[1].max_body_bytes(), 1024); + + let outcome = runner + .evaluate_described(&described, input(r#"password="top-secret""#)) + .await + .expect("evaluate mixed chain"); + assert!(outcome.allowed); + assert_eq!(outcome.applied.len(), 2); + assert_eq!( + String::from_utf8(outcome.body).expect("utf8"), + r#"password="[REDACTED]""# + ); + } + + #[test] + fn external_manifest_rejects_operator_limit_above_capability() { + let registration = external_registration(4097); + let manifest = MiddlewareManifest { + api_version: API_VERSION.into(), + name: "example/service".into(), + service_version: "test".into(), + bindings: vec![MiddlewareBinding { + id: "example/content-guard".into(), + operation: HTTP_REQUEST_OPERATION.into(), + phase: PRE_CREDENTIALS_PHASE.into(), + max_body_bytes: 4096, + }], + }; + let error = validate_external_manifest( + ®istration, + &manifest, + 4097, + &mut HashSet::from([BUILTIN_SECRETS.to_string()]), + ) + .expect_err("operator limit must fit capability"); + assert!(error.to_string().contains("exceeds")); + } + + #[test] + fn external_registration_requires_explicit_insecure_opt_in() { + let mut registration = external_registration(4096); + registration.allow_insecure = false; + let error = validate_registration(®istration).expect_err("opt-in required"); + assert!(error.to_string().contains("allow_insecure")); + } + + #[tokio::test] + async fn external_registry_invokes_remote_service_over_grpc() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test middleware"); + let address = listener.local_addr().expect("test middleware address"); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let server = tonic::transport::Server::builder() + .add_service(SupervisorMiddlewareServer::new(ScriptedService { + binding_id: "example/content-guard".into(), + max_body_bytes: 4096, + result: allow_result(), + })) + .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }); + let server_task = tokio::spawn(server); + + let mut registration = external_registration(1024); + registration.endpoint = format!("http://{address}"); + let registry = MiddlewareRegistry::connect_external(vec![registration.clone()]) + .await + .expect("connect external middleware"); + let policy = SandboxPolicy { + network_middlewares: vec![NetworkMiddlewareConfig { + name: "guard".into(), + middleware: "example/content-guard".into(), + config: Some(prost_types::Struct::default()), + on_error: "fail_closed".into(), + endpoints: None, + }], + ..Default::default() + }; + + registry + .validate_policy_configs(&policy) + .await + .expect("remote config validates"); + assert_eq!( + registry.required_external_services(Some(&policy)), + vec![registration] + ); + + let outcome = ChainRunner::from_registry(registry) + .evaluate( + &[ChainEntry { + name: "guard".into(), + implementation: "example/content-guard".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }], + input("hello"), + ) + .await + .expect("remote evaluation"); + assert!(outcome.allowed); + + let _ = shutdown_tx.send(()); + server_task + .await + .expect("join test middleware") + .expect("serve"); + } + #[tokio::test] async fn deny_decision_short_circuits_chain() { let runner = ChainRunner::new(Arc::new(scripted_service( diff --git a/crates/openshell-supervisor-middleware/src/remote.rs b/crates/openshell-supervisor-middleware/src/remote.rs new file mode 100644 index 000000000..dd147788b --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/remote.rs @@ -0,0 +1,91 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::time::Duration; + +use miette::{IntoDiagnostic, Result, WrapErr}; +use openshell_core::proto::middleware::v1::supervisor_middleware_client::SupervisorMiddlewareClient; +use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; +use openshell_core::proto::{ + HttpRequestEvaluation, HttpRequestResult, MiddlewareManifest, ValidateConfigRequest, + ValidateConfigResponse, +}; +use tonic::transport::{Channel, Endpoint}; +use tonic::{Request, Response, Status}; + +const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); +const RPC_TIMEOUT: Duration = Duration::from_secs(5); + +#[derive(Clone)] +pub struct RemoteMiddlewareService { + registration_name: String, + client: SupervisorMiddlewareClient, +} + +impl RemoteMiddlewareService { + pub async fn connect(registration_name: &str, endpoint: &str) -> Result { + let channel = Endpoint::from_shared(endpoint.to_string()) + .into_diagnostic() + .wrap_err_with(|| { + format!("middleware registration '{registration_name}' has an invalid endpoint") + })? + .connect_timeout(CONNECT_TIMEOUT) + .connect() + .await + .into_diagnostic() + .wrap_err_with(|| { + format!( + "middleware registration '{registration_name}' could not connect to {endpoint}" + ) + })?; + Ok(Self { + registration_name: registration_name.to_string(), + client: SupervisorMiddlewareClient::new(channel), + }) + } + + async fn with_timeout( + &self, + operation: &'static str, + future: impl Future, Status>>, + ) -> std::result::Result, Status> { + tokio::time::timeout(RPC_TIMEOUT, future) + .await + .map_err(|_| { + Status::deadline_exceeded(format!( + "middleware '{}' {operation} timed out", + self.registration_name + )) + })? + } +} + +#[tonic::async_trait] +impl SupervisorMiddleware for RemoteMiddlewareService { + async fn describe( + &self, + request: Request<()>, + ) -> std::result::Result, Status> { + let mut client = self.client.clone(); + self.with_timeout("Describe", client.describe(request)) + .await + } + + async fn validate_config( + &self, + request: Request, + ) -> std::result::Result, Status> { + let mut client = self.client.clone(); + self.with_timeout("ValidateConfig", client.validate_config(request)) + .await + } + + async fn evaluate_http_request( + &self, + request: Request, + ) -> std::result::Result, Status> { + let mut client = self.client.clone(); + self.with_timeout("EvaluateHttpRequest", client.evaluate_http_request(request)) + .await + } +} diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 8383b6bb2..84853f751 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -454,35 +454,41 @@ where if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; - let req = - match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) - .await? - { - MiddlewareApplyResult::Allowed(req) => req, - MiddlewareApplyResult::Denied(reason) => { - crate::l7::rest::RestProvider::default() - .deny_with_redacted_target( - &crate::l7::provider::L7Request { - action: request_info.action.clone(), - target: redacted_target.clone(), - query_params: request_info.query_params.clone(), - raw_header: Vec::new(), - body_length: crate::l7::provider::BodyLength::None, - }, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - }; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, @@ -786,9 +792,11 @@ pub(crate) async fn apply_middleware_chain, + runner: &openshell_supervisor_middleware::ChainRunner, generation_guard: &PolicyGenerationGuard, ) -> Result { - apply_middleware_chain_for_scheme(req, client, ctx, "https", chain, generation_guard).await + apply_middleware_chain_for_scheme(req, client, ctx, "https", chain, runner, generation_guard) + .await } pub(crate) async fn apply_middleware_chain_for_scheme( @@ -797,12 +805,12 @@ pub(crate) async fn apply_middleware_chain_for_scheme, + runner: &openshell_supervisor_middleware::ChainRunner, generation_guard: &PolicyGenerationGuard, ) -> Result { if chain.is_empty() { return Ok(MiddlewareApplyResult::Allowed(req)); } - let runner = openshell_supervisor_middleware::ChainRunner::default(); let chain = runner.describe_chain(&chain).await?; let max_body_bytes = middleware_chain_body_limit(&chain).expect("non-empty middleware chain has a body limit"); @@ -1221,35 +1229,41 @@ where if allowed || config.enforcement == EnforcementMode::Audit { let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; - let req = - match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) - .await? - { - MiddlewareApplyResult::Allowed(req) => req, - MiddlewareApplyResult::Denied(reason) => { - provider - .deny_with_redacted_target( - &crate::l7::provider::L7Request { - action: request_info.action.clone(), - target: redacted_target.clone(), - query_params: request_info.query_params.clone(), - raw_header: Vec::new(), - body_length: crate::l7::provider::BodyLength::None, - }, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - }; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + provider + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let req_with_auth = match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await { Ok(req) => req, @@ -1490,35 +1504,41 @@ where if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; - let req = - match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) - .await? - { - MiddlewareApplyResult::Allowed(req) => req, - MiddlewareApplyResult::Denied(reason) => { - crate::l7::rest::RestProvider::default() - .deny_with_redacted_target( - &crate::l7::provider::L7Request { - action: request_info.action.clone(), - target: redacted_target.clone(), - query_params: request_info.query_params.clone(), - raw_header: Vec::new(), - body_length: crate::l7::provider::BodyLength::None, - }, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - }; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; // Future MCP response/SSE introspection or rewrite would hook here // before returning upstream bytes. The current policy schema has no // trusted-annotations or version-profile field, so MCP responses and @@ -1714,35 +1734,41 @@ where if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; - let req = - match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) - .await? - { - MiddlewareApplyResult::Allowed(req) => req, - MiddlewareApplyResult::Denied(reason) => { - crate::l7::rest::RestProvider::default() - .deny_with_redacted_target( - &crate::l7::provider::L7Request { - action: request_info.action.clone(), - target: redacted_target.clone(), - query_params: request_info.query_params.clone(), - raw_header: Vec::new(), - body_length: crate::l7::provider::BodyLength::None, - }, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - }; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( &req, client, @@ -2170,7 +2196,9 @@ where if generation != generation_guard.captured_generation() { return Ok(()); } - match apply_middleware_chain(req, client, ctx, chain, generation_guard).await? { + let runner = engine.middleware_runner()?; + match apply_middleware_chain(req, client, ctx, chain, &runner, generation_guard).await? + { MiddlewareApplyResult::Allowed(req) => req, MiddlewareApplyResult::Denied(reason) => { crate::l7::rest::RestProvider::default() diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index c4e773996..9e1427dcd 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -13,11 +13,11 @@ use openshell_core::policy::{ }; use openshell_core::proto::SandboxPolicy as ProtoSandboxPolicy; use openshell_policy::L7ConfigStanza; -use openshell_supervisor_middleware::ChainEntry; +use openshell_supervisor_middleware::{ChainEntry, ChainRunner, MiddlewareRegistry}; use std::collections::HashSet; use std::path::{Path, PathBuf}; use std::sync::{ - Arc, Mutex, + Arc, Mutex, RwLock, atomic::{AtomicU64, Ordering}, }; @@ -73,6 +73,7 @@ pub struct SandboxConfig { pub struct OpaEngine { engine: Mutex, generation: Arc, + middleware_runner: RwLock, } /// Generation guard captured when an HTTP tunnel or request path starts. @@ -112,6 +113,7 @@ impl PolicyGenerationGuard { pub struct TunnelPolicyEngine { engine: Mutex, generation_guard: PolicyGenerationGuard, + middleware_runner: ChainRunner, } impl TunnelPolicyEngine { @@ -135,6 +137,10 @@ impl TunnelPolicyEngine { &self.engine } + pub(crate) fn middleware_runner(&self) -> &ChainRunner { + &self.middleware_runner + } + /// Query the ordered middleware chain for a destination within this tunnel. pub fn query_middleware_chain(&self, input: &NetworkInput) -> Result> { let mut engine = self @@ -164,6 +170,7 @@ impl OpaEngine { Ok(Self { engine: Mutex::new(engine), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }) } @@ -182,6 +189,7 @@ impl OpaEngine { Ok(Self { engine: Mutex::new(engine), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }) } @@ -254,6 +262,7 @@ impl OpaEngine { Ok(Self { engine: Mutex::new(engine), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }) } @@ -451,6 +460,25 @@ impl OpaEngine { self.generation.load(Ordering::Acquire) } + /// Replace the complete middleware service registry and invalidate + /// existing tunnels so subsequent requests use the new service set. + pub fn replace_middleware_registry(&self, registry: MiddlewareRegistry) -> Result<()> { + let mut runner = self + .middleware_runner + .write() + .map_err(|_| miette::miette!("middleware runner lock poisoned"))?; + *runner = ChainRunner::from_registry(registry); + self.generation.fetch_add(1, Ordering::AcqRel); + Ok(()) + } + + pub(crate) fn middleware_runner(&self) -> Result { + self.middleware_runner + .read() + .map(|runner| runner.clone()) + .map_err(|_| miette::miette!("middleware runner lock poisoned")) + } + /// Return a guard for a previously captured policy generation. pub fn generation_guard(&self, expected_generation: u64) -> Result { let generation = self.current_generation(); @@ -662,6 +690,7 @@ impl OpaEngine { captured_generation: generation, current_generation: Arc::clone(&self.generation), }, + middleware_runner: self.middleware_runner()?, }) } } @@ -2941,6 +2970,7 @@ network_policies: let engine = OpaEngine { engine: Mutex::new(rego), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }; let input = l7_websocket_graphql_input( "realtime.graphql.com", diff --git a/crates/openshell-supervisor-network/src/proxy.rs b/crates/openshell-supervisor-network/src/proxy.rs index afec98666..8616c3b2c 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -4218,6 +4218,7 @@ async fn handle_forward_proxy( return Ok(()); } if !chain.is_empty() { + let middleware_runner = opa_engine.middleware_runner()?; let request = crate::l7::rest::request_from_buffered_http( method, middleware_path, @@ -4230,6 +4231,7 @@ async fn handle_forward_proxy( &l7_ctx, &scheme, chain, + &middleware_runner, &forward_generation_guard, ) .await? diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index 88b82870d..c28967190 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -103,6 +103,14 @@ guest_tls_key = "/etc/openshell/certs/client-key.pem" grpc_rate_limit_requests = 120 grpc_rate_limit_window_seconds = 60 +# Local-development-only external supervisor middleware. The endpoint must be +# reachable from both the gateway and sandbox supervisors. +[[openshell.gateway.middleware]] +name = "local-content-guard" +endpoint = "http://host.openshell.internal:50051" +allow_insecure = true +max_body_bytes = 262144 + # Gateway listener TLS (distinct from the per-driver guest_tls_*). [openshell.gateway.tls] cert_path = "/etc/openshell/certs/gateway.pem" @@ -140,6 +148,26 @@ Local Docker, Podman, and VM gateways can also set `[openshell.gateway.mtls_auth `[openshell.gateway.auth] allow_unauthenticated_users = true` is an unsafe local-development and trusted-proxy escape hatch. It accepts user-facing CLI/API calls without OIDC or mTLS credentials while sandbox supervisors still authenticate with gateway-minted sandbox JWTs. Leave it false for shared and production gateways. +## External Supervisor Middleware + +Register external supervisor middleware with one or more `[[openshell.gateway.middleware]]` entries. Registration is static and operator-owned; changing it requires restarting the gateway. + +```toml +[[openshell.gateway.middleware]] +name = "local-content-guard" +endpoint = "http://host.openshell.internal:50051" +allow_insecure = true +max_body_bytes = 262144 +``` + +Each service implements the supervisor middleware gRPC contract and may expose multiple binding IDs through `Describe`. Policies reference those binding IDs, not the registration `name`. The gateway rejects duplicate binding IDs across services and prevents external services from claiming the reserved `openshell/` namespace. + +The gateway connects to every registered service and validates `Describe` before it starts. The service must therefore be running before the gateway. Policy creation and full policy updates call `ValidateConfig`; an unavailable service or invalid middleware configuration rejects the policy before persistence. + +`max_body_bytes` is the operator limit for every binding exposed by the service. It must be greater than zero and no larger than each binding's advertised limit. OpenShell rejects an oversized value instead of silently clamping it. + +External middleware is a local-development preview. The endpoint must use plaintext `http://`, and `allow_insecure = true` is required as an explicit acknowledgement that inspected request content is sent without transport encryption or peer authentication. TLS, authentication, health checks, and runtime registration are not supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address when both runtimes resolve it. + `image_pull_policy` is intentionally not a shared gateway key. Kubernetes and Docker use `Always`, `IfNotPresent`, or `Never`. Podman uses `always`, `missing`, `never`, or `newer`. Set it inside the relevant driver table. ## Driver References diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 049ad4797..69ff06dc3 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -20,6 +20,7 @@ filesystem_policy: { ... } landlock: { ... } process: { ... } network_policies: { ... } +network_middlewares: [ ... ] ``` | Field | Type | Required | Category | Description | @@ -29,6 +30,7 @@ network_policies: { ... } | `landlock` | object | No | Static | Configures Landlock LSM enforcement behavior. | | `process` | object | No | Static | Sets the user and group the agent process runs as. | | `network_policies` | map | No | Dynamic | Declares which binaries can reach which network endpoints. | +| `network_middlewares` | list | No | Dynamic | Selects ordered HTTP request middleware by destination host. | Static fields are set at sandbox creation time. Changing them requires destroying and recreating the sandbox. Dynamic fields can be updated on a running sandbox with `openshell policy update` for incremental merges or `openshell policy set` for full replacement, and take effect without restarting. @@ -472,7 +474,35 @@ Identifies an executable that is permitted to use the associated endpoints. |---|---|---|---| | `path` | string | Yes | Filesystem path to the executable. Supports glob patterns with `*` and `**`. For example, `/sandbox/.vscode-server/**` matches any executable under that directory tree. | -### Full Example +## Network Middleware + +**Category:** Dynamic + +An ordered list of middleware configs selected after network and L7 policy admit an HTTP request. Middleware selection is independent of the network policy entry that admitted the request. Every matching config runs once in list order before provider credential injection. + +```yaml showLineNumbers={false} +network_middlewares: + - name: redact-secrets + middleware: openshell/secrets + config: + secrets: redact + on_error: fail_closed + endpoints: + include: ["*.example.com"] + exclude: ["trusted.example.com"] +``` + +| Field | Type | Required | Description | +|---|---|---|---| +| `name` | string | Yes | Policy-local config name. Names must be unique within the list. | +| `middleware` | string | Yes | Built-in or operator-registered binding ID. `openshell/` is reserved for built-ins. | +| `config` | object | No | Implementation-owned configuration validated by the selected middleware. | +| `on_error` | string | No | `fail_closed` denies the request when the stage fails; `fail_open` skips the failed stage. Defaults to `fail_closed`. | +| `endpoints` | object | Yes | Host selector with required non-empty `include` and optional `exclude` lists. Exclusions take precedence. | + +Host selectors use the same case-insensitive exact and DNS glob semantics as network endpoints. Middleware runs only on HTTP requests the supervisor parses. A selector that can require middleware on a `tls: skip` endpoint is rejected because OpenShell cannot inspect that traffic. + +## Full Example The following policy grants read-only GitHub API access and npm registry access: diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 212ba76fc..295acb64c 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -12,7 +12,7 @@ Use this page to apply and iterate policy changes on running sandboxes. For a fu ## Policy Structure -A policy has static sections `filesystem_policy`, `landlock`, and `process` that are locked at sandbox creation, and a dynamic section `network_policies` that is hot-reloadable on a running sandbox. +A policy has static sections `filesystem_policy`, `landlock`, and `process` that are locked at sandbox creation, and dynamic `network_policies` and `network_middlewares` sections that are hot-reloadable on a running sandbox. ```yaml wordWrap showLineNumbers={false} version: 1 @@ -44,6 +44,17 @@ network_policies: binaries: - path: /usr/bin/curl +# Dynamic: ordered middleware selected independently by admitted host. +network_middlewares: + - name: redact-secrets + middleware: openshell/secrets + config: + secrets: redact + on_error: fail_closed + endpoints: + include: ["api.example.com"] + exclude: [] + ``` Static sections are locked at sandbox creation. Changing them requires destroying and recreating the sandbox. @@ -57,6 +68,29 @@ Raw streams are connection-scoped and outside L7 live-reload guarantees. This in | `landlock` | Static | Configures Landlock LSM enforcement behavior. Set `compatibility` to `best_effort` (skip individual inaccessible paths while applying remaining rules) or `hard_requirement` (fail if any path is inaccessible or the required kernel ABI is unavailable). Refer to the [Policy Schema Reference](/reference/policy-schema#landlock) for the full behavior table. | | `process` | Static | Sets the OS-level identity for the agent process. `run_as_user` and `run_as_group` default to `sandbox`. Root (`root` or `0`) is rejected. The agent also runs with seccomp filters that block dangerous system calls. | | `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path). For endpoints with `protocol: websocket`, the proxy validates the RFC 6455 upgrade and evaluates `GET` rules for the handshake plus either `WEBSOCKET_TEXT` rules for raw client text messages or GraphQL operation rules for GraphQL-over-WebSocket messages. Set `websocket_credential_rewrite: true` only when a WebSocket or REST compatibility endpoint must keep placeholder credentials in sandbox-owned text frames and resolve them at the OpenShell relay boundary.
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | +| `network_middlewares` | Dynamic | Declares ordered HTTP request middleware configs. After network and L7 policy admit a request, OpenShell matches each config's host selectors independently and runs matching entries in declaration order before credential injection. | + +## Supervisor Middleware + +Supervisor middleware can inspect, deny, or replace admitted HTTP request bodies before provider credentials are injected. Middleware selection is independent of the `network_policies` rule that admitted the request: each `network_middlewares` entry matches the destination host through `endpoints.include` and `endpoints.exclude`. + +```yaml +network_middlewares: + - name: redact-secrets + middleware: openshell/secrets + config: + secrets: redact + on_error: fail_closed + endpoints: + include: ["*.example.com"] + exclude: ["trusted.example.com"] +``` + +Matching entries run once each in top-level declaration order. Config names must be unique. Different config names may use the same implementation and run as distinct stages. `exclude` takes precedence over `include`. + +`openshell/secrets` is built into the supervisor. External binding IDs must be registered by the gateway operator before a policy can reference them; see [External Supervisor Middleware](/reference/gateway-config#external-supervisor-middleware). The gateway calls the implementation's `ValidateConfig` before accepting the policy. + +`on_error` defaults to `fail_closed`. Use `fail_open` only when skipping a failed middleware is acceptable. Middleware applies only to HTTP traffic the supervisor can parse and inspect; policy validation rejects a required selector that can cover a `tls: skip` endpoint. ## Baseline Filesystem Paths diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 04cbd6776..afec58723 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -352,4 +352,20 @@ message GetSandboxConfigResponse { // Fingerprint for provider credential inputs attached to this sandbox. // Changes when attached provider names or attached provider records change. uint64 provider_env_revision = 8; + // Operator-registered external middleware services required by the effective + // policy. Built-in middleware is not included. + repeated ExternalMiddlewareService external_middleware = 9; +} + +// Connection details for one operator-registered external middleware service. +// V1 supports only explicitly enabled plaintext gRPC for local development. +message ExternalMiddlewareService { + // Operator-facing registration name used for diagnostics. + string name = 1; + // gRPC endpoint reachable from the sandbox supervisor. + string endpoint = 2; + // Explicit acknowledgement that request content is sent without TLS. + bool allow_insecure = 3; + // Operator-owned body limit applied to every binding exposed by the service. + uint64 max_body_bytes = 4; } From 762887a6eb72c9ec5c668a665c7b0c6e327b2b60 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Mon, 29 Jun 2026 15:38:08 -0700 Subject: [PATCH 08/27] refactor(supervisor-middleware): clarify service contract Signed-off-by: Piotr Mlocek --- architecture/sandbox.md | 4 +- crates/openshell-core/src/grpc_client.rs | 4 +- crates/openshell-sandbox/src/lib.rs | 27 ++++---- crates/openshell-server/src/config_file.rs | 10 +-- crates/openshell-server/src/grpc/policy.rs | 17 +++-- crates/openshell-server/src/lib.rs | 2 +- .../src/lib.rs | 62 +++++++++---------- docs/reference/gateway-config.mdx | 12 ++-- docs/reference/policy-schema.mdx | 4 ++ docs/sandboxes/policies.mdx | 6 +- proto/sandbox.proto | 10 +-- 11 files changed, 85 insertions(+), 73 deletions(-) diff --git a/architecture/sandbox.md b/architecture/sandbox.md index deec00f32..d63c4fbaa 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -69,7 +69,7 @@ are relayed but are not currently parsed for policy enforcement. For admitted HTTP requests, the proxy can run an ordered supervisor middleware chain before credential injection. Host selectors choose the chain independently of the network rule that admitted the request. Built-ins run in-process; -operator-registered external services are called directly from the supervisor +operator-registered services are called directly from the supervisor over the common middleware gRPC contract. The gateway validates external service capabilities and policy-owned config before delivery. Supervisors keep the last-known-good service registry when a live config reload fails. @@ -184,7 +184,7 @@ quickly. - If gateway config polling fails, the sandbox keeps its last-known-good policy. - If a live policy update is invalid, the supervisor rejects it and keeps the current policy. -- If an external middleware call fails, the selected config's `on_error` +- If an operator-run middleware call fails, the selected config's `on_error` behavior decides whether to deny the request or continue without that stage. - Existing raw byte streams are connection scoped. Dynamic policy changes apply to new connections or the next parsed HTTP request where the proxy can safely diff --git a/crates/openshell-core/src/grpc_client.rs b/crates/openshell-core/src/grpc_client.rs index 57f72dca6..836c7880c 100644 --- a/crates/openshell-core/src/grpc_client.rs +++ b/crates/openshell-core/src/grpc_client.rs @@ -728,7 +728,7 @@ pub struct SettingsPollResult { /// When `policy_source` is `Global`, the version of the global policy revision. pub global_policy_version: u32, pub provider_env_revision: u64, - pub external_middleware: Vec, + pub supervisor_middleware_services: Vec, } pub struct ProviderEnvironmentResult { @@ -773,7 +773,7 @@ impl CachedOpenShellClient { settings: inner.settings, global_policy_version: inner.global_policy_version, provider_env_revision: inner.provider_env_revision, - external_middleware: inner.external_middleware, + supervisor_middleware_services: inner.supervisor_middleware_services, }) } diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 30da1c292..5e640d73f 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -1483,8 +1483,8 @@ async fn load_policy( info!("Creating OPA engine from proto policy data"); let engine = OpaEngine::from_proto(&proto_policy)?; let middleware_registry = - openshell_supervisor_middleware::MiddlewareRegistry::connect_external( - sandbox_config.external_middleware, + openshell_supervisor_middleware::MiddlewareRegistry::connect_services( + sandbox_config.supervisor_middleware_services, ) .await?; engine.replace_middleware_registry(middleware_registry)?; @@ -1638,7 +1638,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { let mut current_config_revision: u64 = 0; let mut current_provider_env_revision: u64 = ctx.provider_credentials.snapshot().revision; let mut current_policy_hash = String::new(); - let mut current_external_middleware = Vec::new(); + let mut current_middleware_services = Vec::new(); let mut current_settings: std::collections::HashMap< String, openshell_core::proto::EffectiveSetting, @@ -1650,7 +1650,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { apply_ocsf_json_setting(&ctx.ocsf_enabled, &result.settings); current_config_revision = result.config_revision; current_policy_hash = result.policy_hash.clone(); - current_external_middleware = result.external_middleware; + current_middleware_services = result.supervisor_middleware_services; current_settings = result.settings; debug!( config_revision = current_config_revision, @@ -1680,7 +1680,8 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } let policy_changed = result.policy_hash != current_policy_hash; - let middleware_changed = result.external_middleware != current_external_middleware; + let middleware_changed = + result.supervisor_middleware_services != current_middleware_services; // Log which settings changed. log_setting_changes(¤t_settings, &result.settings); @@ -1740,26 +1741,26 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } if middleware_changed { - match openshell_supervisor_middleware::MiddlewareRegistry::connect_external( - result.external_middleware.clone(), + match openshell_supervisor_middleware::MiddlewareRegistry::connect_services( + result.supervisor_middleware_services.clone(), ) .await .and_then(|registry| ctx.opa_engine.replace_middleware_registry(registry)) { Ok(()) => { - current_external_middleware = result.external_middleware.clone(); + current_middleware_services = result.supervisor_middleware_services.clone(); ocsf_emit!( ConfigStateChangeBuilder::new(ocsf_ctx()) .severity(SeverityId::Informational) .status(StatusId::Success) .state(StateId::Enabled, "loaded") .unmapped( - "external_middleware_count", - serde_json::json!(current_external_middleware.len()) + "supervisor_middleware_service_count", + serde_json::json!(current_middleware_services.len()) ) .message(format!( - "External middleware registry reloaded [service_count:{}]", - current_external_middleware.len() + "Supervisor middleware registry reloaded [service_count:{}]", + current_middleware_services.len() )) .build() ); @@ -1771,7 +1772,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { .status(StatusId::Failure) .state(StateId::Other, "failed") .message(format!( - "External middleware registry reload failed, keeping last-known-good registry [error:{error}]" + "Supervisor middleware registry reload failed, keeping last-known-good registry [error:{error}]" )) .build() ); diff --git a/crates/openshell-server/src/config_file.rs b/crates/openshell-server/src/config_file.rs index 13c7e9ebb..4b0fbc919 100644 --- a/crates/openshell-server/src/config_file.rs +++ b/crates/openshell-server/src/config_file.rs @@ -25,7 +25,7 @@ use std::net::SocketAddr; use std::path::{Path, PathBuf}; use openshell_core::config::ComputeDriverKind; -use openshell_core::proto::ExternalMiddlewareService; +use openshell_core::proto::SupervisorMiddlewareService; use openshell_core::{GatewayAuthConfig, GatewayJwtConfig, MtlsAuthConfig, OidcConfig, TlsConfig}; use serde::{Deserialize, Serialize}; @@ -153,7 +153,7 @@ pub struct GatewayFileSection { pub gateway_jwt: Option, // ── Supervisor middleware ───────────────────────────────────────────── - /// Statically registered external middleware services. Registration is + /// Statically registered supervisor middleware services. Registration is /// operator-owned and changes require a gateway restart. #[serde(default)] pub middleware: Vec, @@ -167,7 +167,7 @@ pub struct GatewayFileSection { pub database_url: Option, } -/// One `[[openshell.gateway.middleware]]` external middleware registration. +/// One `[[openshell.gateway.middleware]]` supervisor middleware registration. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct MiddlewareServiceFileConfig { @@ -182,7 +182,7 @@ pub struct MiddlewareServiceFileConfig { pub max_body_bytes: u64, } -impl From<&MiddlewareServiceFileConfig> for ExternalMiddlewareService { +impl From<&MiddlewareServiceFileConfig> for SupervisorMiddlewareService { fn from(config: &MiddlewareServiceFileConfig) -> Self { Self { name: config.name.clone(), @@ -435,7 +435,7 @@ allow_unauthenticated_users = true } #[test] - fn parses_external_middleware_registration() { + fn parses_supervisor_middleware_registration() { let toml = r#" [[openshell.gateway.middleware]] name = "local-guard" diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index c29de71ff..9587e7295 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -1230,14 +1230,13 @@ pub(super) async fn handle_get_sandbox_config( } let settings = merge_effective_settings(&global_settings, &sandbox_settings)?; - let external_middleware = state - .middleware_registry - .required_external_services(policy.as_ref()); + let supervisor_middleware_services = + state.middleware_registry.required_services(policy.as_ref()); let config_revision = compute_config_revision( policy.as_ref(), &settings, policy_source, - &external_middleware, + &supervisor_middleware_services, ); let provider_env_revision = compute_provider_env_revision(state.store.as_ref(), &sandbox_provider_names).await?; @@ -1251,7 +1250,7 @@ pub(super) async fn handle_get_sandbox_config( policy_source: policy_source.into(), global_policy_version, provider_env_revision, - external_middleware, + supervisor_middleware_services, })) } @@ -3144,7 +3143,7 @@ fn compute_config_revision( policy: Option<&ProtoSandboxPolicy>, settings: &HashMap, policy_source: PolicySource, - external_middleware: &[openshell_core::proto::ExternalMiddlewareService], + supervisor_middleware_services: &[openshell_core::proto::SupervisorMiddlewareService], ) -> u64 { let mut hasher = Sha256::new(); hasher.update((policy_source as i32).to_le_bytes()); @@ -3177,7 +3176,7 @@ fn compute_config_revision( } } } - let mut middleware = external_middleware.iter().collect::>(); + let mut middleware = supervisor_middleware_services.iter().collect::>(); middleware.sort_by(|left, right| left.name.cmp(&right.name)); for service in middleware { hasher.update(service.encode_to_vec()); @@ -9327,10 +9326,10 @@ mod tests { } #[test] - fn config_revision_changes_when_external_middleware_changes() { + fn config_revision_changes_when_supervisor_middleware_services_change() { let policy = ProtoSandboxPolicy::default(); let settings = HashMap::new(); - let service = openshell_core::proto::ExternalMiddlewareService { + let service = openshell_core::proto::SupervisorMiddlewareService { name: "local-guard".into(), endpoint: "http://127.0.0.1:50051".into(), allow_insecure: true, diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index bca8abe2e..bd974d4b4 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -242,7 +242,7 @@ pub(crate) async fn run_server( }) .unwrap_or_default(); let middleware_registry = Arc::new( - MiddlewareRegistry::connect_external(middleware_registrations) + MiddlewareRegistry::connect_services(middleware_registrations) .await .map_err(|error| Error::config(format!("middleware registration failed: {error}")))?, ); diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index 828179d18..ebc5817e4 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -15,9 +15,9 @@ pub use service::InProcessMiddlewareService; use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; use openshell_core::proto::{ - Decision, ExternalMiddlewareService, Finding, HttpRequestEvaluation, HttpRequestTarget, - MiddlewareBinding, MiddlewareManifest, NetworkMiddlewareConfig, RequestContext, SandboxPolicy, - ValidateConfigRequest, + Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, MiddlewareBinding, + MiddlewareManifest, NetworkMiddlewareConfig, RequestContext, SandboxPolicy, + SupervisorMiddlewareService, ValidateConfigRequest, }; use tokio::sync::OnceCell; use tonic::Request; @@ -206,13 +206,13 @@ static IN_PROCESS_SERVICE: LazyLock> = LazyLock::new /// Validated middleware services available to a gateway or one supervisor. /// -/// The registry always contains the in-process built-ins. External services -/// are connected and described before construction succeeds, so callers never +/// The registry always contains the in-process built-ins. Operator-registered +/// services are connected and described before construction succeeds, so callers never /// observe a partially registered service set. #[derive(Clone)] pub struct MiddlewareRegistry { services: Arc>>, - external: Arc>, + registered_services: Arc>, } impl std::fmt::Debug for MiddlewareRegistry { @@ -220,14 +220,14 @@ impl std::fmt::Debug for MiddlewareRegistry { formatter .debug_struct("MiddlewareRegistry") .field("service_count", &self.services.len()) - .field("external_count", &self.external.len()) + .field("registered_service_count", &self.registered_services.len()) .finish() } } #[derive(Clone)] -struct RegisteredExternalService { - registration: ExternalMiddlewareService, +struct RegisteredMiddlewareService { + registration: SupervisorMiddlewareService, binding_ids: Vec, } @@ -235,15 +235,15 @@ impl Default for MiddlewareRegistry { fn default() -> Self { Self { services: Arc::new(vec![Arc::clone(&IN_PROCESS_SERVICE)]), - external: Arc::new(Vec::new()), + registered_services: Arc::new(Vec::new()), } } } -fn validate_registration(registration: &ExternalMiddlewareService) -> Result<()> { +fn validate_registration(registration: &SupervisorMiddlewareService) -> Result<()> { if registration.name.trim().is_empty() { return Err(miette!( - "external middleware registration name cannot be empty" + "supervisor middleware registration name cannot be empty" )); } if !registration.allow_insecure { @@ -268,7 +268,7 @@ fn validate_registration(registration: &ExternalMiddlewareService) -> Result<()> } fn validate_external_manifest( - registration: &ExternalMiddlewareService, + registration: &SupervisorMiddlewareService, manifest: &MiddlewareManifest, operator_max_body_bytes: usize, known_binding_ids: &mut HashSet, @@ -339,10 +339,10 @@ fn validate_external_manifest( } impl MiddlewareRegistry { - /// Connect and validate every external service registration. - pub async fn connect_external(registrations: Vec) -> Result { + /// Connect and validate every operator-provided service registration. + pub async fn connect_services(registrations: Vec) -> Result { let mut services = vec![Arc::clone(&IN_PROCESS_SERVICE)]; - let mut external = Vec::with_capacity(registrations.len()); + let mut registered_services = Vec::with_capacity(registrations.len()); let mut registration_names = HashSet::new(); let mut binding_ids = HashSet::from([BUILTIN_SECRETS.to_string()]); @@ -350,7 +350,7 @@ impl MiddlewareRegistry { validate_registration(®istration)?; if !registration_names.insert(registration.name.clone()) { return Err(miette!( - "duplicate external middleware registration name '{}'", + "duplicate supervisor middleware registration name '{}'", registration.name )); } @@ -395,7 +395,7 @@ impl MiddlewareRegistry { manifest: manifest_cell, operator_max_body_bytes: Some(operator_max_body_bytes), })); - external.push(RegisteredExternalService { + registered_services.push(RegisteredMiddlewareService { registration, binding_ids: described_ids, }); @@ -403,7 +403,7 @@ impl MiddlewareRegistry { Ok(Self { services: Arc::new(services), - external: Arc::new(external), + registered_services: Arc::new(registered_services), }) } @@ -433,7 +433,7 @@ impl MiddlewareRegistry { pub fn ensure_policy_bindings_registered(&self, policy: &SandboxPolicy) -> Result<()> { for config in &policy.network_middlewares { let registered = config.middleware == BUILTIN_SECRETS - || self.external.iter().any(|service| { + || self.registered_services.iter().any(|service| { service .binding_ids .iter() @@ -450,11 +450,11 @@ impl MiddlewareRegistry { Ok(()) } - /// Return only external services referenced by the effective policy. - pub fn required_external_services( + /// Return only operator-registered services referenced by the effective policy. + pub fn required_services( &self, policy: Option<&SandboxPolicy>, - ) -> Vec { + ) -> Vec { let Some(policy) = policy else { return Vec::new(); }; @@ -463,7 +463,7 @@ impl MiddlewareRegistry { .iter() .map(|config| config.middleware.as_str()) .collect(); - self.external + self.registered_services .iter() .filter(|service| { service @@ -491,7 +491,7 @@ impl ChainRunner { manifest: OnceCell::new(), operator_max_body_bytes: None, })]), - external: Arc::new(Vec::new()), + registered_services: Arc::new(Vec::new()), }), } } @@ -1126,8 +1126,8 @@ mod tests { } } - fn external_registration(max_body_bytes: u64) -> ExternalMiddlewareService { - ExternalMiddlewareService { + fn external_registration(max_body_bytes: u64) -> SupervisorMiddlewareService { + SupervisorMiddlewareService { name: "local-guard-service".into(), endpoint: "http://127.0.0.1:50051".into(), allow_insecure: true, @@ -1137,7 +1137,7 @@ mod tests { async fn registry_with_external( service: Arc, - registration: ExternalMiddlewareService, + registration: SupervisorMiddlewareService, ) -> MiddlewareRegistry { let manifest = service .describe(Request::new(())) @@ -1164,7 +1164,7 @@ mod tests { operator_max_body_bytes: Some(operator_max_body_bytes), }), ]), - external: Arc::new(vec![RegisteredExternalService { + registered_services: Arc::new(vec![RegisteredMiddlewareService { registration, binding_ids, }]), @@ -1294,7 +1294,7 @@ mod tests { let mut registration = external_registration(1024); registration.endpoint = format!("http://{address}"); - let registry = MiddlewareRegistry::connect_external(vec![registration.clone()]) + let registry = MiddlewareRegistry::connect_services(vec![registration.clone()]) .await .expect("connect external middleware"); let policy = SandboxPolicy { @@ -1313,7 +1313,7 @@ mod tests { .await .expect("remote config validates"); assert_eq!( - registry.required_external_services(Some(&policy)), + registry.required_services(Some(&policy)), vec![registration] ); diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index c28967190..1b3554c20 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -148,9 +148,13 @@ Local Docker, Podman, and VM gateways can also set `[openshell.gateway.mtls_auth `[openshell.gateway.auth] allow_unauthenticated_users = true` is an unsafe local-development and trusted-proxy escape hatch. It accepts user-facing CLI/API calls without OIDC or mTLS credentials while sandbox supervisors still authenticate with gateway-minted sandbox JWTs. Leave it false for shared and production gateways. -## External Supervisor Middleware +## Supervisor Middleware Services -Register external supervisor middleware with one or more `[[openshell.gateway.middleware]]` entries. Registration is static and operator-owned; changing it requires restarting the gateway. + +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware, not for production or long-lived integrations. + + +Register operator-run supervisor middleware services with one or more `[[openshell.gateway.middleware]]` entries. Registration is static and operator-owned; changing it requires restarting the gateway. ```toml [[openshell.gateway.middleware]] @@ -160,13 +164,13 @@ allow_insecure = true max_body_bytes = 262144 ``` -Each service implements the supervisor middleware gRPC contract and may expose multiple binding IDs through `Describe`. Policies reference those binding IDs, not the registration `name`. The gateway rejects duplicate binding IDs across services and prevents external services from claiming the reserved `openshell/` namespace. +Each service implements the supervisor middleware gRPC contract and may expose multiple binding IDs through `Describe`. Policies reference those binding IDs, not the registration `name`. The gateway rejects duplicate binding IDs across services and prevents operator-run services from claiming the reserved `openshell/` namespace. The gateway connects to every registered service and validates `Describe` before it starts. The service must therefore be running before the gateway. Policy creation and full policy updates call `ValidateConfig`; an unavailable service or invalid middleware configuration rejects the policy before persistence. `max_body_bytes` is the operator limit for every binding exposed by the service. It must be greater than zero and no larger than each binding's advertised limit. OpenShell rejects an oversized value instead of silently clamping it. -External middleware is a local-development preview. The endpoint must use plaintext `http://`, and `allow_insecure = true` is required as an explicit acknowledgement that inspected request content is sent without transport encryption or peer authentication. TLS, authentication, health checks, and runtime registration are not supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address when both runtimes resolve it. +The service endpoint must use plaintext `http://`, and `allow_insecure = true` is required as an explicit acknowledgement that inspected request content is sent without transport encryption or peer authentication. TLS, authentication, health checks, and runtime registration are not supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address when both runtimes resolve it. `image_pull_policy` is intentionally not a shared gateway key. Kubernetes and Docker use `Always`, `IfNotPresent`, or `Never`. Podman uses `always`, `missing`, `never`, or `newer`. Set it inside the relevant driver table. diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 69ff06dc3..9279b6f66 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -476,6 +476,10 @@ Identifies an executable that is permitted to use the associated endpoints. ## Network Middleware + +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware, not for production or long-lived integrations. + + **Category:** Dynamic An ordered list of middleware configs selected after network and L7 policy admit an HTTP request. Middleware selection is independent of the network policy entry that admitted the request. Every matching config runs once in list order before provider credential injection. diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 295acb64c..5280cc085 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -72,6 +72,10 @@ Raw streams are connection-scoped and outside L7 live-reload guarantees. This in ## Supervisor Middleware + +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware, not for production or long-lived integrations. + + Supervisor middleware can inspect, deny, or replace admitted HTTP request bodies before provider credentials are injected. Middleware selection is independent of the `network_policies` rule that admitted the request: each `network_middlewares` entry matches the destination host through `endpoints.include` and `endpoints.exclude`. ```yaml @@ -88,7 +92,7 @@ network_middlewares: Matching entries run once each in top-level declaration order. Config names must be unique. Different config names may use the same implementation and run as distinct stages. `exclude` takes precedence over `include`. -`openshell/secrets` is built into the supervisor. External binding IDs must be registered by the gateway operator before a policy can reference them; see [External Supervisor Middleware](/reference/gateway-config#external-supervisor-middleware). The gateway calls the implementation's `ValidateConfig` before accepting the policy. +`openshell/secrets` is built into the supervisor. Operator-provided binding IDs must be registered before a policy can reference them; see [Supervisor Middleware Services](/reference/gateway-config#supervisor-middleware-services). The gateway calls the implementation's `ValidateConfig` before accepting the policy. `on_error` defaults to `fail_closed`. Use `fail_open` only when skipping a failed middleware is acceptable. Middleware applies only to HTTP traffic the supervisor can parse and inspect; policy validation rejects a required selector that can cover a `tls: skip` endpoint. diff --git a/proto/sandbox.proto b/proto/sandbox.proto index afec58723..644fd86cb 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -352,14 +352,14 @@ message GetSandboxConfigResponse { // Fingerprint for provider credential inputs attached to this sandbox. // Changes when attached provider names or attached provider records change. uint64 provider_env_revision = 8; - // Operator-registered external middleware services required by the effective - // policy. Built-in middleware is not included. - repeated ExternalMiddlewareService external_middleware = 9; + // Operator-registered supervisor middleware services required by the + // effective policy. Built-in middleware is not included. + repeated SupervisorMiddlewareService supervisor_middleware_services = 9; } -// Connection details for one operator-registered external middleware service. +// Connection details for one operator-registered supervisor middleware service. // V1 supports only explicitly enabled plaintext gRPC for local development. -message ExternalMiddlewareService { +message SupervisorMiddlewareService { // Operator-facing registration name used for diagnostics. string name = 1; // gRPC endpoint reachable from the sandbox supervisor. From 31662d8601ce8dcffd2dfe822e07966bac9d01f6 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Mon, 29 Jun 2026 20:18:37 -0700 Subject: [PATCH 09/27] docs(supervisor-middleware): refine preview warning Signed-off-by: Piotr Mlocek --- docs/reference/gateway-config.mdx | 2 +- docs/reference/policy-schema.mdx | 2 +- docs/sandboxes/policies.mdx | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index 1b3554c20..7036238a2 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -151,7 +151,7 @@ Local Docker, Podman, and VM gateways can also set `[openshell.gateway.mtls_auth ## Supervisor Middleware Services -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware, not for production or long-lived integrations. +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. Register operator-run supervisor middleware services with one or more `[[openshell.gateway.middleware]]` entries. Registration is static and operator-owned; changing it requires restarting the gateway. diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 9279b6f66..6ef48d869 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -477,7 +477,7 @@ Identifies an executable that is permitted to use the associated endpoints. ## Network Middleware -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware, not for production or long-lived integrations. +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. **Category:** Dynamic diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 5280cc085..8e71440e5 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -73,7 +73,7 @@ Raw streams are connection-scoped and outside L7 live-reload guarantees. This in ## Supervisor Middleware -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware, not for production or long-lived integrations. +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. Supervisor middleware can inspect, deny, or replace admitted HTTP request bodies before provider credentials are injected. Middleware selection is independent of the `network_policies` rule that admitted the request: each `network_middlewares` entry matches the destination host through `endpoints.include` and `endpoints.exclude`. From 2304e2fcc28c6af267d9b60139b8fde0b3763b5b Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 10:22:43 -0700 Subject: [PATCH 10/27] docs(extensibility): add supervisor middleware guide Signed-off-by: Piotr Mlocek --- docs/extensibility/supervisor-middleware.mdx | 141 +++++++++++++++++++ docs/index.yml | 2 + docs/reference/gateway-config.mdx | 6 +- docs/reference/policy-schema.mdx | 6 +- docs/sandboxes/policies.mdx | 8 +- 5 files changed, 150 insertions(+), 13 deletions(-) create mode 100644 docs/extensibility/supervisor-middleware.mdx diff --git a/docs/extensibility/supervisor-middleware.mdx b/docs/extensibility/supervisor-middleware.mdx new file mode 100644 index 000000000..320495e59 --- /dev/null +++ b/docs/extensibility/supervisor-middleware.mdx @@ -0,0 +1,141 @@ +--- +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +title: "Supervisor Middleware" +sidebar-title: "Supervisor Middleware" +description: "Configure and operate built-in and operator-run middleware for sandbox HTTP requests." +keywords: "Generative AI, Cybersecurity, AI Agents, Supervisor Middleware, Extensibility, Request Filtering" +--- + +Supervisor middleware adds ordered request-processing stages to allowed HTTP egress. Middleware runs after network and L7 policy admit a request and before OpenShell injects provider credentials. A stage can allow or deny the request, replace its body, add approved headers, and report audit-safe findings. + +Middleware selection is independent of the network policy rule that admitted the request. OpenShell matches middleware by destination host, so the same middleware applies consistently across broad, specific, user-authored, and provider-derived network policies. + +## Request Flow + +For each inspected HTTP request, the supervisor: + +1. Evaluates network and L7 policy. +2. Selects middleware whose host selectors match the admitted destination. +3. Buffers the request body using the smallest body limit in the selected chain. +4. Runs matching middleware in policy declaration order. +5. Applies allowed transformations, injects provider credentials, and forwards the request. + +Middleware receives the request before credential injection. Operator-run services cannot inspect OpenShell-managed credentials. + +## Choose a Middleware Type + +| Type | Registration | Body limit | Deployment | +| --- | --- | --- | --- | +| Built-in | None | Defined by OpenShell | Runs inside the supervisor | +| Operator-run service | Required in gateway TOML | Set by the operator, up to the service capability | Runs as a separate service reachable by the gateway and supervisors | + +`openshell/secrets` is the built-in middleware currently available. It identifies common secret patterns in UTF-8 request bodies and replaces matched values before the request leaves the sandbox. + +Operator-run services expose one or more binding IDs. Policies reference a binding ID, such as `example/content-guard`, rather than the gateway registration name. + +## Register a Middleware Service + +Start an operator-run service before starting the gateway, then add a registration to the local gateway TOML: + +```toml +[[openshell.gateway.middleware]] +name = "local-content-guard" +endpoint = "http://host.openshell.internal:50051" +allow_insecure = true +max_body_bytes = 262144 +``` + +| Field | Description | +| --- | --- | +| `name` | Operator-facing registration name used in diagnostics. Policies do not reference this value. | +| `endpoint` | Service address reachable from both the gateway and sandbox supervisors. | +| `allow_insecure` | Required acknowledgement for the currently supported plaintext endpoint. | +| `max_body_bytes` | Operator limit applied to every binding exposed by the service. | + +The gateway connects to every registered service and verifies its capabilities before accepting traffic. Gateway startup fails when a service is unavailable, reports an invalid capability, or exposes a binding ID already owned by another service. Operator-run services cannot claim the reserved `openshell/` namespace. + +Registration is static. Restart the gateway after adding, removing, or changing a service. See [Gateway Configuration](/reference/gateway-config#supervisor-middleware-services) for the complete gateway TOML context. + +## Apply Middleware with Policy + +Add middleware configs to the top-level `network_middlewares` list: + +```yaml +network_middlewares: + - name: redact-secrets + middleware: openshell/secrets + config: + secrets: redact + on_error: fail_closed + endpoints: + include: ["*.example.com"] + exclude: ["trusted.example.com"] +``` + +Each config has a policy-local `name`, a built-in or operator-provided binding ID in `middleware`, implementation-owned `config`, failure behavior, and host selectors. + +`include` selects destination hosts. `exclude` takes precedence and removes hosts from that selection. Matching is case-insensitive and uses the same exact-host and DNS glob behavior as network policy endpoints. + +Matching configs run once each in top-level declaration order. Different config names may reference the same binding and run as separate stages. Config names must be unique. + +See [Policy Schema](/reference/policy-schema#network-middleware) for the complete field reference. + +## Configure Failure Behavior + +`on_error` controls what happens when middleware is unavailable, rejects its configuration, returns an invalid result, or exceeds a body limit. + +| Value | Behavior | +| --- | --- | +| `fail_closed` | Denies the request when the middleware stage fails. This is the default. | +| `fail_open` | Skips the failed stage and continues the request through the remaining chain. | + +Use `fail_open` only when bypassing the middleware preserves the intended security policy. OpenShell emits a detection finding when a failed stage is bypassed. + +An explicit deny decision always stops the chain and denies the request, regardless of `on_error`. + +## Set Body Limits + +Every middleware binding declares the largest request or replacement body it supports. + +- Built-in middleware uses its OpenShell-defined limit. +- Each operator-run registration sets `max_body_bytes` no higher than the service capability. +- A selected chain buffers using its smallest stage limit. +- The same per-stage limit applies to request bodies and replacement bodies. + +The gateway rejects a registration whose operator limit exceeds the service capability instead of silently clamping it. At request time, exceeding a selected stage's limit is a middleware failure and follows that config's `on_error` behavior. + +## Operate Middleware Services + +Plan startup and updates around these boundaries: + +- Start registered services before the gateway. The gateway validates every registration during startup. +- Keep service endpoints reachable from both the gateway and sandbox supervisors. The supervisors call operator-run services directly on the request path. +- Restart the gateway after changing registrations. +- Keep required services available before creating or updating policies. The gateway validates implementation-owned config before persisting a policy. +- Treat `fail_open` as an explicit availability-over-enforcement decision. + +When the effective sandbox configuration changes, a running supervisor validates the new service registry before installing it. If the reload fails, the supervisor keeps its last-known-good registry and emits a configuration failure event. + +## Observe Middleware + +Middleware activity is emitted through OpenShell's OCSF logging: + +- Each invocation records its policy-local middleware name, binding, decision, transformation state, and failure state. +- A bypass under `fail_open` emits a detection finding. +- A required stage that fails closed emits a high-severity detection finding. +- Findings include the service-provided type and label plus aggregate counts. Middleware services should keep those fields audit-safe and omit request content or matched values. +- Registry reload success and failure are emitted as configuration state changes. + +See [Logging](/observability/logging) for log access and [OCSF JSON Export](/observability/ocsf-json-export) for structured export. + +## Current Limitations + +- Middleware applies only to HTTP requests parsed by the supervisor. +- The supported operation and phase are `HttpRequest/pre_credentials`. +- Selection uses destination host include and exclude patterns. +- Required middleware cannot cover `tls: skip` endpoints because OpenShell cannot inspect that traffic. +- Operator-run services currently use explicitly enabled plaintext `http://` endpoints. +- TLS, service authentication, health checks, and runtime registration are not available. + +For a runnable operator workflow, see the [content guard example](https://github.com/NVIDIA/OpenShell/tree/main/examples/supervisor-middleware-content-guard). diff --git a/docs/index.yml b/docs/index.yml index b2443e4af..45db451a7 100644 --- a/docs/index.yml +++ b/docs/index.yml @@ -19,6 +19,8 @@ navigation: title: "Manage OpenShell" - folder: providers title: "Providers" +- folder: extensibility + title: "Extensibility" - folder: observability title: "Observability" - folder: kubernetes diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index 7036238a2..cad7a4e88 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -150,10 +150,6 @@ Local Docker, Podman, and VM gateways can also set `[openshell.gateway.mtls_auth ## Supervisor Middleware Services - -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. - - Register operator-run supervisor middleware services with one or more `[[openshell.gateway.middleware]]` entries. Registration is static and operator-owned; changing it requires restarting the gateway. ```toml @@ -172,6 +168,8 @@ The gateway connects to every registered service and validates `Describe` before The service endpoint must use plaintext `http://`, and `allow_insecure = true` is required as an explicit acknowledgement that inspected request content is sent without transport encryption or peer authentication. TLS, authentication, health checks, and runtime registration are not supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address when both runtimes resolve it. +See [Supervisor Middleware](/extensibility/supervisor-middleware) for selection, failure, body-limit, and operational guidance. + `image_pull_policy` is intentionally not a shared gateway key. Kubernetes and Docker use `Always`, `IfNotPresent`, or `Never`. Podman uses `always`, `missing`, `never`, or `newer`. Set it inside the relevant driver table. ## Driver References diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 6ef48d869..e36c535be 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -476,10 +476,6 @@ Identifies an executable that is permitted to use the associated endpoints. ## Network Middleware - -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. - - **Category:** Dynamic An ordered list of middleware configs selected after network and L7 policy admit an HTTP request. Middleware selection is independent of the network policy entry that admitted the request. Every matching config runs once in list order before provider credential injection. @@ -506,6 +502,8 @@ network_middlewares: Host selectors use the same case-insensitive exact and DNS glob semantics as network endpoints. Middleware runs only on HTTP requests the supervisor parses. A selector that can require middleware on a `tls: skip` endpoint is rejected because OpenShell cannot inspect that traffic. +See [Supervisor Middleware](/extensibility/supervisor-middleware) for registration, failure behavior, body limits, and operational guidance. + ## Full Example The following policy grants read-only GitHub API access and npm registry access: diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 8e71440e5..1353fd640 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -72,10 +72,6 @@ Raw streams are connection-scoped and outside L7 live-reload guarantees. This in ## Supervisor Middleware - -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. - - Supervisor middleware can inspect, deny, or replace admitted HTTP request bodies before provider credentials are injected. Middleware selection is independent of the `network_policies` rule that admitted the request: each `network_middlewares` entry matches the destination host through `endpoints.include` and `endpoints.exclude`. ```yaml @@ -92,10 +88,12 @@ network_middlewares: Matching entries run once each in top-level declaration order. Config names must be unique. Different config names may use the same implementation and run as distinct stages. `exclude` takes precedence over `include`. -`openshell/secrets` is built into the supervisor. Operator-provided binding IDs must be registered before a policy can reference them; see [Supervisor Middleware Services](/reference/gateway-config#supervisor-middleware-services). The gateway calls the implementation's `ValidateConfig` before accepting the policy. +`openshell/secrets` is built into the supervisor. Operator-provided binding IDs must be registered before a policy can reference them. The gateway validates implementation-owned config before accepting the policy. `on_error` defaults to `fail_closed`. Use `fail_open` only when skipping a failed middleware is acceptable. Middleware applies only to HTTP traffic the supervisor can parse and inspect; policy validation rejects a required selector that can cover a `tls: skip` endpoint. +See [Supervisor Middleware](/extensibility/supervisor-middleware) for registration, chain ordering, body limits, failure behavior, and operations. + ## Baseline Filesystem Paths When a sandbox runs in proxy mode (the default), OpenShell automatically adds baseline filesystem paths required for the sandbox child process to function: `/usr`, `/lib`, `/etc`, `/var/log` (read-only) and `/sandbox`, `/tmp` (read-write). Paths like `/app` are included in the baseline set but are only added if they exist in the container image. From 97bef95a7a4b44dd97cad1a7b7a7cb653e168179 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 12:41:09 -0700 Subject: [PATCH 11/27] fix(server): remove stale middleware import Signed-off-by: Piotr Mlocek --- crates/openshell-server/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index bd974d4b4..86afe4ef4 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -55,7 +55,6 @@ mod ws_tunnel; use metrics_exporter_prometheus::PrometheusBuilder; use openshell_core::{ComputeDriverKind, Config, Error, Result}; use openshell_supervisor_middleware::MiddlewareRegistry; -use serde::Deserialize; use std::collections::HashMap; use std::io::ErrorKind; use std::net::SocketAddr; From ab2daba91ff79e10f82dcc52764746bf239be91e Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 13:26:06 -0700 Subject: [PATCH 12/27] fix(network): remove needless test struct updates Signed-off-by: Piotr Mlocek --- crates/openshell-supervisor-network/src/opa.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index 9e1427dcd..f0aea9c38 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -6698,7 +6698,6 @@ network_policies: path: link_path, ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -6777,7 +6776,6 @@ network_policies: path: link_path, ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { From a820dd289778c5f39df6294e3c5906a8463a81a3 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 13:48:26 -0700 Subject: [PATCH 13/27] fix(middleware): avoid enabling core telemetry Signed-off-by: Piotr Mlocek --- crates/openshell-supervisor-middleware/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/openshell-supervisor-middleware/Cargo.toml b/crates/openshell-supervisor-middleware/Cargo.toml index e5e53618d..0af6e70c2 100644 --- a/crates/openshell-supervisor-middleware/Cargo.toml +++ b/crates/openshell-supervisor-middleware/Cargo.toml @@ -11,7 +11,7 @@ repository.workspace = true rust-version.workspace = true [dependencies] -openshell-core = { path = "../openshell-core" } +openshell-core = { path = "../openshell-core", default-features = false } miette = { workspace = true } prost-types = { workspace = true } From 49424a074f0ec41138c09e6f9965ab5b402a26bd Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 14:47:35 -0700 Subject: [PATCH 14/27] refactor(supervisor-middleware): simplify service endpoints Signed-off-by: Piotr Mlocek --- crates/openshell-server/src/config_file.rs | 14 ++-- crates/openshell-server/src/grpc/policy.rs | 3 +- .../Cargo.toml | 2 +- .../src/lib.rs | 39 ++++++----- .../src/remote.rs | 23 +++++-- docs/extensibility/supervisor-middleware.mdx | 10 ++- docs/reference/gateway-config.mdx | 12 ++-- proto/middleware.proto | 66 +++++++++++++++++++ proto/sandbox.proto | 10 +-- 9 files changed, 127 insertions(+), 52 deletions(-) diff --git a/crates/openshell-server/src/config_file.rs b/crates/openshell-server/src/config_file.rs index 4b0fbc919..7306c80e7 100644 --- a/crates/openshell-server/src/config_file.rs +++ b/crates/openshell-server/src/config_file.rs @@ -174,10 +174,7 @@ pub struct MiddlewareServiceFileConfig { /// Operator-facing name used for diagnostics. pub name: String, /// Plaintext gRPC endpoint reachable by the gateway and supervisors. - pub endpoint: String, - /// Required explicit opt-in to the local-development-only insecure mode. - #[serde(default)] - pub allow_insecure: bool, + pub grpc_endpoint: String, /// Operator-owned body limit for every binding exposed by this service. pub max_body_bytes: u64, } @@ -186,8 +183,7 @@ impl From<&MiddlewareServiceFileConfig> for SupervisorMiddlewareService { fn from(config: &MiddlewareServiceFileConfig) -> Self { Self { name: config.name.clone(), - endpoint: config.endpoint.clone(), - allow_insecure: config.allow_insecure, + grpc_endpoint: config.grpc_endpoint.clone(), max_body_bytes: config.max_body_bytes, } } @@ -439,8 +435,7 @@ allow_unauthenticated_users = true let toml = r#" [[openshell.gateway.middleware]] name = "local-guard" -endpoint = "http://127.0.0.1:50051" -allow_insecure = true +grpc_endpoint = "http://127.0.0.1:50051" max_body_bytes = 262144 "#; let tmp = write_tmp(toml); @@ -449,8 +444,7 @@ max_body_bytes = 262144 file.openshell.gateway.middleware, vec![MiddlewareServiceFileConfig { name: "local-guard".into(), - endpoint: "http://127.0.0.1:50051".into(), - allow_insecure: true, + grpc_endpoint: "http://127.0.0.1:50051".into(), max_body_bytes: 262_144, }] ); diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 9587e7295..97d459813 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -9331,8 +9331,7 @@ mod tests { let settings = HashMap::new(); let service = openshell_core::proto::SupervisorMiddlewareService { name: "local-guard".into(), - endpoint: "http://127.0.0.1:50051".into(), - allow_insecure: true, + grpc_endpoint: "http://127.0.0.1:50051".into(), max_body_bytes: 1024, }; diff --git a/crates/openshell-supervisor-middleware/Cargo.toml b/crates/openshell-supervisor-middleware/Cargo.toml index 0af6e70c2..f36fc854b 100644 --- a/crates/openshell-supervisor-middleware/Cargo.toml +++ b/crates/openshell-supervisor-middleware/Cargo.toml @@ -17,7 +17,7 @@ miette = { workspace = true } prost-types = { workspace = true } regex = { workspace = true } tokio = { workspace = true } -tonic = { workspace = true, features = ["channel", "server"] } +tonic = { workspace = true, features = ["channel", "server", "tls-native-roots"] } [dev-dependencies] tokio-stream = { workspace = true, features = ["net"] } diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index ebc5817e4..00324d75f 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -246,15 +246,11 @@ fn validate_registration(registration: &SupervisorMiddlewareService) -> Result<( "supervisor middleware registration name cannot be empty" )); } - if !registration.allow_insecure { - return Err(miette!( - "middleware registration '{}' must set allow_insecure = true; TLS is not supported in V1", - registration.name - )); - } - if !registration.endpoint.starts_with("http://") { + if !registration.grpc_endpoint.starts_with("http://") + && !registration.grpc_endpoint.starts_with("https://") + { return Err(miette!( - "middleware registration '{}' endpoint must use http:// in the local-development-only V1", + "middleware registration '{}' grpc_endpoint must use http:// or https://", registration.name )); } @@ -365,7 +361,7 @@ impl MiddlewareRegistry { let service = Arc::new( remote::RemoteMiddlewareService::connect( ®istration.name, - ®istration.endpoint, + ®istration.grpc_endpoint, ) .await?, ); @@ -1129,8 +1125,7 @@ mod tests { fn external_registration(max_body_bytes: u64) -> SupervisorMiddlewareService { SupervisorMiddlewareService { name: "local-guard-service".into(), - endpoint: "http://127.0.0.1:50051".into(), - allow_insecure: true, + grpc_endpoint: "http://127.0.0.1:50051".into(), max_body_bytes, } } @@ -1267,11 +1262,23 @@ mod tests { } #[test] - fn external_registration_requires_explicit_insecure_opt_in() { + fn external_registration_accepts_http_and_https_grpc_endpoints() { + for grpc_endpoint in [ + "http://127.0.0.1:50051", + "https://middleware.example.com:443", + ] { + let mut registration = external_registration(4096); + registration.grpc_endpoint = grpc_endpoint.into(); + validate_registration(®istration).expect("supported gRPC endpoint scheme"); + } + } + + #[test] + fn external_registration_rejects_unsupported_grpc_endpoint_scheme() { let mut registration = external_registration(4096); - registration.allow_insecure = false; - let error = validate_registration(®istration).expect_err("opt-in required"); - assert!(error.to_string().contains("allow_insecure")); + registration.grpc_endpoint = "ftp://middleware.example.com".into(); + let error = validate_registration(®istration).expect_err("unsupported scheme"); + assert!(error.to_string().contains("http:// or https://")); } #[tokio::test] @@ -1293,7 +1300,7 @@ mod tests { let server_task = tokio::spawn(server); let mut registration = external_registration(1024); - registration.endpoint = format!("http://{address}"); + registration.grpc_endpoint = format!("http://{address}"); let registry = MiddlewareRegistry::connect_services(vec![registration.clone()]) .await .expect("connect external middleware"); diff --git a/crates/openshell-supervisor-middleware/src/remote.rs b/crates/openshell-supervisor-middleware/src/remote.rs index dd147788b..7645ed811 100644 --- a/crates/openshell-supervisor-middleware/src/remote.rs +++ b/crates/openshell-supervisor-middleware/src/remote.rs @@ -10,7 +10,7 @@ use openshell_core::proto::{ HttpRequestEvaluation, HttpRequestResult, MiddlewareManifest, ValidateConfigRequest, ValidateConfigResponse, }; -use tonic::transport::{Channel, Endpoint}; +use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; use tonic::{Request, Response, Status}; const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); @@ -23,19 +23,30 @@ pub struct RemoteMiddlewareService { } impl RemoteMiddlewareService { - pub async fn connect(registration_name: &str, endpoint: &str) -> Result { - let channel = Endpoint::from_shared(endpoint.to_string()) + pub async fn connect(registration_name: &str, grpc_endpoint: &str) -> Result { + let mut endpoint = Endpoint::from_shared(grpc_endpoint.to_string()) .into_diagnostic() .wrap_err_with(|| { - format!("middleware registration '{registration_name}' has an invalid endpoint") - })? + format!( + "middleware registration '{registration_name}' has an invalid grpc_endpoint" + ) + })?; + if grpc_endpoint.starts_with("https://") { + endpoint = endpoint + .tls_config(ClientTlsConfig::new().with_enabled_roots()) + .into_diagnostic() + .wrap_err_with(|| { + format!("middleware registration '{registration_name}' could not configure TLS") + })?; + } + let channel = endpoint .connect_timeout(CONNECT_TIMEOUT) .connect() .await .into_diagnostic() .wrap_err_with(|| { format!( - "middleware registration '{registration_name}' could not connect to {endpoint}" + "middleware registration '{registration_name}' could not connect to {grpc_endpoint}" ) })?; Ok(Self { diff --git a/docs/extensibility/supervisor-middleware.mdx b/docs/extensibility/supervisor-middleware.mdx index 320495e59..0ebd518bd 100644 --- a/docs/extensibility/supervisor-middleware.mdx +++ b/docs/extensibility/supervisor-middleware.mdx @@ -41,16 +41,14 @@ Start an operator-run service before starting the gateway, then add a registrati ```toml [[openshell.gateway.middleware]] name = "local-content-guard" -endpoint = "http://host.openshell.internal:50051" -allow_insecure = true +grpc_endpoint = "http://host.openshell.internal:50051" max_body_bytes = 262144 ``` | Field | Description | | --- | --- | | `name` | Operator-facing registration name used in diagnostics. Policies do not reference this value. | -| `endpoint` | Service address reachable from both the gateway and sandbox supervisors. | -| `allow_insecure` | Required acknowledgement for the currently supported plaintext endpoint. | +| `grpc_endpoint` | Service address reachable from both the gateway and sandbox supervisors. Supports plaintext `http://` and TLS `https://` with platform trust roots. | | `max_body_bytes` | Operator limit applied to every binding exposed by the service. | The gateway connects to every registered service and verifies its capabilities before accepting traffic. Gateway startup fails when a service is unavailable, reports an invalid capability, or exposes a binding ID already owned by another service. Operator-run services cannot claim the reserved `openshell/` namespace. @@ -135,7 +133,7 @@ See [Logging](/observability/logging) for log access and [OCSF JSON Export](/obs - The supported operation and phase are `HttpRequest/pre_credentials`. - Selection uses destination host include and exclude patterns. - Required middleware cannot cover `tls: skip` endpoints because OpenShell cannot inspect that traffic. -- Operator-run services currently use explicitly enabled plaintext `http://` endpoints. -- TLS, service authentication, health checks, and runtime registration are not available. +- Operator-run services support plaintext `http://` and TLS `https://` endpoints. HTTPS certificates must chain to a CA in the platform trust store. +- Custom trust roots, client authentication, health checks, and runtime registration are not available. For a runnable operator workflow, see the [content guard example](https://github.com/NVIDIA/OpenShell/tree/main/examples/supervisor-middleware-content-guard). diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index cad7a4e88..64540e512 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -103,12 +103,11 @@ guest_tls_key = "/etc/openshell/certs/client-key.pem" grpc_rate_limit_requests = 120 grpc_rate_limit_window_seconds = 60 -# Local-development-only external supervisor middleware. The endpoint must be -# reachable from both the gateway and sandbox supervisors. +# Operator-run supervisor middleware. The gRPC endpoint must be reachable from +# both the gateway and sandbox supervisors. [[openshell.gateway.middleware]] name = "local-content-guard" -endpoint = "http://host.openshell.internal:50051" -allow_insecure = true +grpc_endpoint = "http://host.openshell.internal:50051" max_body_bytes = 262144 # Gateway listener TLS (distinct from the per-driver guest_tls_*). @@ -155,8 +154,7 @@ Register operator-run supervisor middleware services with one or more `[[openshe ```toml [[openshell.gateway.middleware]] name = "local-content-guard" -endpoint = "http://host.openshell.internal:50051" -allow_insecure = true +grpc_endpoint = "http://host.openshell.internal:50051" max_body_bytes = 262144 ``` @@ -166,7 +164,7 @@ The gateway connects to every registered service and validates `Describe` before `max_body_bytes` is the operator limit for every binding exposed by the service. It must be greater than zero and no larger than each binding's advertised limit. OpenShell rejects an oversized value instead of silently clamping it. -The service endpoint must use plaintext `http://`, and `allow_insecure = true` is required as an explicit acknowledgement that inspected request content is sent without transport encryption or peer authentication. TLS, authentication, health checks, and runtime registration are not supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address when both runtimes resolve it. +The service `grpc_endpoint` currently supports plaintext `http://` and TLS `https://` using the platform trust store. Custom trust roots, client authentication, health checks, and runtime registration are not currently supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address that can be resolved in both places. See [Supervisor Middleware](/extensibility/supervisor-middleware) for selection, failure, body-limit, and operational guidance. diff --git a/proto/middleware.proto b/proto/middleware.proto index 2944227d8..9b988b930 100644 --- a/proto/middleware.proto +++ b/proto/middleware.proto @@ -8,90 +8,156 @@ package openshell.middleware.v1; import "google/protobuf/empty.proto"; import "google/protobuf/struct.proto"; +// SupervisorMiddleware lets an operator-run service inspect and transform +// sandbox HTTP egress before OpenShell injects credentials. service SupervisorMiddleware { + // Describe returns the service manifest and declared bindings. rpc Describe(google.protobuf.Empty) returns (MiddlewareManifest); + + // ValidateConfig checks service-specific configuration for one binding. rpc ValidateConfig(ValidateConfigRequest) returns (ValidateConfigResponse); + + // EvaluateHttpRequest returns an allow, deny, or mutation decision for one + // buffered HTTP request. rpc EvaluateHttpRequest(HttpRequestEvaluation) returns (HttpRequestResult); } +// MiddlewareManifest describes one service and the bindings it exposes. message MiddlewareManifest { + // Middleware protocol version implemented by the service. string api_version = 1; + // Human-readable service name used for diagnostics. string name = 2; + // Service-defined version string used for diagnostics. string service_version = 3; + // Bindings exposed by this service. repeated MiddlewareBinding bindings = 4; } +// MiddlewareBinding declares one operation and phase supported by a service. message MiddlewareBinding { + // Stable binding id used by policy configuration and audit logs. string id = 1; + // Supported operation name. V1 supports "HttpRequest". string operation = 2; + // Supported evaluation phase. V1 supports "pre_credentials". string phase = 3; // Maximum request or replacement body this binding can process. uint64 max_body_bytes = 4; } +// ValidateConfigRequest contains one policy configuration to validate. message ValidateConfigRequest { + // Middleware protocol version selected by OpenShell. string api_version = 1; + // Manifest binding id associated with this configuration. string binding_id = 2; + // Service-specific policy configuration. google.protobuf.Struct config = 3; } +// ValidateConfigResponse reports whether a policy configuration is accepted. message ValidateConfigResponse { + // True when the service accepts the configuration. bool valid = 1; + // Human-readable validation failure reason. Empty when valid is true. string reason = 2; } +// HttpRequestEvaluation contains one buffered HTTP request to evaluate. message HttpRequestEvaluation { + // Middleware protocol version selected by OpenShell. string api_version = 1; + // Manifest binding id selected for this evaluation. string binding_id = 2; + // Evaluation phase selected for this request. string phase = 3; + // Sandbox and request identity available to the supervisor. RequestContext context = 4; + // Validated service-specific policy configuration. google.protobuf.Struct config = 5; + // Destination and HTTP request target. HttpRequestTarget target = 6; + // HTTP request headers before OpenShell injects credentials. map headers = 7; + // Buffered request body. Empty for a bodyless request. bytes body = 8; } +// RequestContext identifies the sandbox request being evaluated. message RequestContext { + // Request id used to correlate middleware and supervisor logs. string request_id = 1; + // Sandbox id that originated the request. string sandbox_id = 2; + // Workload process that originated the request, when available. Process originating_process = 3; } +// HttpRequestTarget describes the admitted HTTP destination and request target. message HttpRequestTarget { + // Request scheme, such as "http" or "https". string scheme = 1; + // Destination hostname selected by network policy. string host = 2; + // Destination TCP port. uint32 port = 3; + // HTTP request method. string method = 4; + // Request path without the query string. string path = 5; + // Raw request query string without the leading question mark. string query = 6; } +// Process identifies a workload process and its executable ancestry. message Process { + // Executable path for the originating process. string binary = 1; + // Process id within the sandbox. uint32 pid = 2; + // Executable paths for ancestor processes, nearest parent first. repeated string ancestors = 3; } +// Decision controls whether OpenShell continues processing the request. enum Decision { + // Invalid response value handled according to the policy failure mode. DECISION_UNSPECIFIED = 0; + // Continue processing the request and apply any returned mutations. DECISION_ALLOW = 1; + // Deny the request before credentials are injected or data is sent upstream. DECISION_DENY = 2; } +// Finding is an audit-safe observation produced during evaluation. message Finding { + // Stable, service-defined finding type. string type = 1; + // Human-readable finding label that does not contain request content. string label = 2; + // Number of matching observations represented by this finding. uint32 count = 3; + // Service-defined confidence level. string confidence = 4; + // Service-defined severity level. string severity = 5; } +// HttpRequestResult contains the decision and optional request mutations. message HttpRequestResult { + // Allow or deny decision for this request. Decision decision = 1; + // Human-readable reason used for diagnostics and denied responses. string reason = 2; + // Replacement request body when has_body is true. bytes body = 3; + // True when body should replace the request body, including with an empty body. bool has_body = 4; + // Request headers to add before forwarding. Protected headers are rejected. map add_headers = 5; + // Audit-safe findings produced during evaluation. repeated Finding findings = 6; + // Non-secret service-defined metadata included in diagnostics. map metadata = 7; } diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 644fd86cb..c4018573d 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -79,8 +79,12 @@ message NetworkMiddlewareConfig { MiddlewareEndpointSelector endpoints = 5; } +// Host selector controlling which admitted destinations use a middleware config. message MiddlewareEndpointSelector { + // Exact host or DNS glob patterns included in the selection. repeated string include = 1; + // Exact host or DNS glob patterns removed from the selection. + // Exclusions take precedence over inclusions. repeated string exclude = 2; } @@ -358,14 +362,12 @@ message GetSandboxConfigResponse { } // Connection details for one operator-registered supervisor middleware service. -// V1 supports only explicitly enabled plaintext gRPC for local development. +// V1 supports plaintext and server-authenticated TLS gRPC. message SupervisorMiddlewareService { // Operator-facing registration name used for diagnostics. string name = 1; // gRPC endpoint reachable from the sandbox supervisor. - string endpoint = 2; - // Explicit acknowledgement that request content is sent without TLS. - bool allow_insecure = 3; + string grpc_endpoint = 2; // Operator-owned body limit applied to every binding exposed by the service. uint64 max_body_bytes = 4; } From 3e69c5f4ac50977af3df89117aa07b357801427a Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 17:31:09 -0700 Subject: [PATCH 15/27] fix(supervisor-middleware): keep sandbox startup resilient to middleware outages An unreachable operator-registered middleware service previously aborted sandbox startup via a hard error in load_policy, contradicting the per-request on_error contract and the resilient live-reload path. Retry the initial connect and, on failure, degrade to the built-in registry so matched requests are governed by each config's on_error (deny for fail_closed, allow for fail_open) instead of blocking the whole sandbox. The policy poll loop now reconciles the registry on every poll while an install is pending, so a recovered service is adopted without waiting for a config change; a failed reconcile also no longer blocks unrelated policy updates. Signed-off-by: Piotr Mlocek --- crates/openshell-sandbox/src/lib.rs | 159 +++++++++++++++++++--------- 1 file changed, 109 insertions(+), 50 deletions(-) diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 5e640d73f..25750ee14 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -128,7 +128,7 @@ pub async fn run_sandbox( // Load policy and initialize OPA engine let openshell_endpoint_for_proxy = openshell_endpoint.clone(); let sandbox_name_for_agg = sandbox.clone(); - let (mut policy, opa_engine, retained_proto) = load_policy( + let (mut policy, opa_engine, retained_proto, middleware_install_pending) = load_policy( sandbox_id.clone(), sandbox, openshell_endpoint.clone(), @@ -423,6 +423,7 @@ pub async fn run_sandbox( ocsf_enabled: poll_ocsf_enabled, provider_credentials: poll_provider_credentials, policy_local_ctx: poll_policy_local, + middleware_install_pending, }; tokio::spawn(async move { @@ -1370,6 +1371,11 @@ async fn load_policy( SandboxPolicy, Option>, Option, + // True when operator-registered middleware could not be connected at + // startup and the engine kept the built-in registry. The policy poll loop + // retries the install so a recovered service is picked up without a config + // change. + bool, )> { // File mode: load OPA engine from rego rules + YAML data (dev override) if let (Some(policy_file), Some(data_file)) = (&policy_rules, &policy_data) { @@ -1399,7 +1405,8 @@ async fn load_policy( process: config.process, }; enrich_sandbox_baseline_paths(&mut policy); - return Ok((policy, Some(Arc::new(engine)), None)); + // File mode has no operator-registered middleware to connect. + return Ok((policy, Some(Arc::new(engine)), None, false)); } // gRPC mode: fetch typed proto policy, construct OPA engine from baked rules + proto data @@ -1482,16 +1489,45 @@ async fn load_policy( // engine is rebuilt with the real PID for symlink resolution. info!("Creating OPA engine from proto policy data"); let engine = OpaEngine::from_proto(&proto_policy)?; - let middleware_registry = + // Connect operator-registered middleware services. A connect/describe + // failure must not abort sandbox startup: unlike the previous hard + // failure, we degrade to the built-in registry and let each request's + // `on_error` policy govern matched traffic (deny for fail_closed, allow + // for fail_open). The policy poll loop retries the install so a + // recovered service is picked up without a config change. This mirrors + // the resilient live-reload path. + let middleware_services = sandbox_config.supervisor_middleware_services.clone(); + let middleware_install_pending = match grpc_retry("Middleware connect", || { openshell_supervisor_middleware::MiddlewareRegistry::connect_services( - sandbox_config.supervisor_middleware_services, + middleware_services.clone(), ) - .await?; - engine.replace_middleware_registry(middleware_registry)?; + }) + .await + .and_then(|registry| engine.replace_middleware_registry(registry)) + { + Ok(()) => false, + Err(error) => { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "degraded") + .unmapped( + "supervisor_middleware_service_count", + serde_json::json!(middleware_services.len()) + ) + .message(format!( + "Supervisor middleware connect failed at startup; continuing with built-in middleware only, per-request on_error governs matched requests [error:{error}]" + )) + .build() + ); + true + } + }; let opa_engine = Some(Arc::new(engine)); let policy = SandboxPolicy::try_from(proto_policy.clone())?; - return Ok((policy, opa_engine, Some(proto_policy))); + return Ok((policy, opa_engine, Some(proto_policy), middleware_install_pending)); } // No policy source available @@ -1627,6 +1663,10 @@ struct PolicyPollLoopContext { ocsf_enabled: Arc, provider_credentials: ProviderCredentialState, policy_local_ctx: Option>, + /// True when `load_policy` degraded to the built-in middleware registry + /// because operator services could not be connected at startup. The poll + /// loop retries the install until it succeeds. + middleware_install_pending: bool, } async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { @@ -1639,6 +1679,10 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { let mut current_provider_env_revision: u64 = ctx.provider_credentials.snapshot().revision; let mut current_policy_hash = String::new(); let mut current_middleware_services = Vec::new(); + // Set when a middleware install is outstanding (degraded at startup or a + // failed reload). Drives a retry on every poll, independent of the config + // revision, so a recovered operator service is picked up promptly. + let mut middleware_sync_pending = ctx.middleware_install_pending; let mut current_settings: std::collections::HashMap< String, openshell_core::proto::EffectiveSetting, @@ -1674,14 +1718,70 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } }; + // Reconcile the supervisor middleware registry before evaluating the + // rest of the config. This runs independently of the config revision so + // an install that degraded at startup (or failed on an earlier poll) is + // retried here, letting a recovered operator service be picked up + // without waiting for a policy change. A failure keeps the + // last-known-good registry; the request path stays governed by each + // middleware's `on_error` policy, and a config change is still applied + // below rather than being blocked by the middleware outage. + if middleware_sync_pending + || result.supervisor_middleware_services != current_middleware_services + { + match openshell_supervisor_middleware::MiddlewareRegistry::connect_services( + result.supervisor_middleware_services.clone(), + ) + .await + .and_then(|registry| ctx.opa_engine.replace_middleware_registry(registry)) + { + Ok(()) => { + current_middleware_services = result.supervisor_middleware_services.clone(); + middleware_sync_pending = false; + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped( + "supervisor_middleware_service_count", + serde_json::json!(current_middleware_services.len()) + ) + .message(format!( + "Supervisor middleware registry reloaded [service_count:{}]", + current_middleware_services.len() + )) + .build() + ); + } + Err(error) => { + // Emit only on the transition into the failed state to avoid + // repeating the same finding on every poll during a + // sustained outage. The startup degrade path emits its own + // event. + if !middleware_sync_pending { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "failed") + .message(format!( + "Supervisor middleware registry reload failed, keeping last-known-good registry [error:{error}]" + )) + .build() + ); + } + middleware_sync_pending = true; + } + } + } + let provider_env_changed = result.provider_env_revision != current_provider_env_revision; if result.config_revision == current_config_revision && !provider_env_changed { continue; } let policy_changed = result.policy_hash != current_policy_hash; - let middleware_changed = - result.supervisor_middleware_services != current_middleware_services; // Log which settings changed. log_setting_changes(¤t_settings, &result.settings); @@ -1740,47 +1840,6 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } } - if middleware_changed { - match openshell_supervisor_middleware::MiddlewareRegistry::connect_services( - result.supervisor_middleware_services.clone(), - ) - .await - .and_then(|registry| ctx.opa_engine.replace_middleware_registry(registry)) - { - Ok(()) => { - current_middleware_services = result.supervisor_middleware_services.clone(); - ocsf_emit!( - ConfigStateChangeBuilder::new(ocsf_ctx()) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .state(StateId::Enabled, "loaded") - .unmapped( - "supervisor_middleware_service_count", - serde_json::json!(current_middleware_services.len()) - ) - .message(format!( - "Supervisor middleware registry reloaded [service_count:{}]", - current_middleware_services.len() - )) - .build() - ); - } - Err(error) => { - ocsf_emit!( - ConfigStateChangeBuilder::new(ocsf_ctx()) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .state(StateId::Other, "failed") - .message(format!( - "Supervisor middleware registry reload failed, keeping last-known-good registry [error:{error}]" - )) - .build() - ); - continue; - } - } - } - // Only reload OPA when the policy payload actually changed. if policy_changed { let Some(policy) = result.policy.as_ref() else { From 9bdc7f7736937e2cce793c96f6a4ede1dbff2b3f Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 17:31:17 -0700 Subject: [PATCH 16/27] fix(supervisor-middleware): ignore unresolved bindings in chain body limit A chain entry whose binding did not resolve reported a zero body limit, which dragged the whole chain's buffer cap to zero and spuriously failed body-bearing requests over capacity even when a resolved middleware could have processed them. Exclude unresolved entries from the limit via a new DescribedChainEntry::is_resolved(); when no entry resolves, skip buffering and apply each entry's on_error directly. Also fix two parallel-test flakes found while validating the change: - Build middleware OCSF events into a Vec and assert on it directly instead of capturing through the global tracing pipeline, whose callsite-interest cache is process-global and raced under parallel runs. - Accumulate the websocket deny response until the reason marker arrives rather than assuming a single read returns the full body. Signed-off-by: Piotr Mlocek --- .../src/lib.rs | 27 +++ .../src/l7/relay.rs | 201 ++++++++++++++---- 2 files changed, 188 insertions(+), 40 deletions(-) diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index 00324d75f..27302dfe4 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -102,6 +102,14 @@ impl DescribedChainEntry { pub fn on_error(&self) -> OnError { self.entry.on_error } + + /// True when this entry resolved to a registered binding and will be + /// evaluated. When false, the binding is absent from the current registry + /// and the entry is handled entirely by its `on_error` policy, so it + /// imposes no body-buffering limit on the chain. + pub fn is_resolved(&self) -> bool { + self.binding.is_some() + } } #[derive(Debug, Clone)] @@ -1166,6 +1174,25 @@ mod tests { } } + #[tokio::test] + async fn describe_chain_marks_resolved_and_unresolved_entries() { + let unresolved = ChainEntry { + name: "missing".into(), + implementation: "third-party/missing".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailOpen, + }; + let described = ChainRunner::default() + .describe_chain(&[entry("redact", OnError::FailClosed), unresolved]) + .await + .expect("describe chain"); + // The built-in resolves and reports its real limit; the missing binding + // does not resolve and must not contribute a body limit. + assert!(described[0].is_resolved()); + assert_eq!(described[0].max_body_bytes(), 256 * 1024); + assert!(!described[1].is_resolved()); + } + #[tokio::test] async fn descriptors_are_resolved_from_any_middleware_service() { let runner = ChainRunner::new(Arc::new(ScriptedService { diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 84853f751..3e92a3218 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -778,11 +778,18 @@ pub(crate) enum MiddlewareApplyResult { Denied(String), } +/// Smallest body-buffering limit across the entries that actually resolved to a +/// registered binding. Unresolved entries (`is_resolved() == false`) report a +/// zero limit and are excluded here: they are handled by their `on_error` policy +/// in `evaluate_described` without inspecting the body, so letting a zero drag +/// the chain limit to zero would spuriously fail the whole chain over capacity. +/// Returns `None` when no entry resolved, so the caller can skip buffering. fn middleware_chain_body_limit( chain: &[openshell_supervisor_middleware::DescribedChainEntry], ) -> Option { chain .iter() + .filter(|entry| entry.is_resolved()) .map(openshell_supervisor_middleware::DescribedChainEntry::max_body_bytes) .min() } @@ -812,8 +819,20 @@ pub(crate) async fn apply_middleware_chain_for_scheme crate::opa::NetworkInput { } } -fn emit_middleware_events( +/// Build the OCSF events describing a middleware chain outcome, in emission +/// order. Separated from `emit_middleware_events` so tests can assert on the +/// events deterministically without routing through the global tracing pipeline, +/// whose callsite-interest cache is process-global and races under parallel +/// tests. +fn middleware_events( ctx: &L7EvalContext, req: &crate::l7::provider::L7Request, outcome: &openshell_supervisor_middleware::ChainOutcome, -) { +) -> Vec { + let mut events = Vec::new(); for invocation in &outcome.applied { let allowed = invocation.decision == openshell_core::proto::Decision::Allow; let event = HttpActivityBuilder::new(openshell_ocsf::ctx::ctx()) @@ -1000,7 +1025,7 @@ fn emit_middleware_events( invocation.failed )) .build(); - ocsf_emit!(event); + events.push(event); // A middleware that failed but was bypassed under `fail_open` is an // enforcement failure operators must be able to alert on, even though the @@ -1021,7 +1046,7 @@ fn emit_middleware_events( invocation.name )) .build(); - ocsf_emit!(event); + events.push(event); } } if !outcome.allowed && outcome.reason.starts_with("middleware_failed:") { @@ -1033,7 +1058,7 @@ fn emit_middleware_events( )) .message("Required supervisor middleware failed closed") .build(); - ocsf_emit!(event); + events.push(event); } for finding in &outcome.findings { let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) @@ -1055,6 +1080,19 @@ fn emit_middleware_events( finding.finding.r#type, finding.finding.count )) .build(); + events.push(event); + } + events +} + +/// Emit the OCSF events describing a middleware chain outcome through the +/// tracing pipeline. +fn emit_middleware_events( + ctx: &L7EvalContext, + req: &crate::l7::provider::L7Request, + outcome: &openshell_supervisor_middleware::ChainOutcome, +) { + for event in middleware_events(ctx, req, outcome) { ocsf_emit!(event); } } @@ -3051,6 +3089,101 @@ network_policies: )); } + #[tokio::test] + async fn body_limit_ignores_unresolved_entries() { + use openshell_supervisor_middleware::{ChainEntry, ChainRunner, OnError}; + + let resolved = ChainEntry { + name: "redact".into(), + implementation: openshell_supervisor_middleware::BUILTIN_SECRETS.into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }; + let unresolved = ChainEntry { + name: "missing".into(), + implementation: "third-party/missing".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailOpen, + }; + + // A single unresolved (0-limit) entry must not drag the chain limit to + // zero: the buffer limit reflects only the resolved built-in. + let mixed = ChainRunner::default() + .describe_chain(&[resolved, unresolved.clone()]) + .await + .expect("describe mixed chain"); + assert_eq!(middleware_chain_body_limit(&mixed), Some(256 * 1024)); + + // When nothing resolves, there is no body limit and the caller skips + // buffering entirely. + let none = ChainRunner::default() + .describe_chain(std::slice::from_ref(&unresolved)) + .await + .expect("describe unresolved chain"); + assert_eq!(middleware_chain_body_limit(&none), None); + } + + #[tokio::test] + async fn all_unresolved_fail_open_forwards_body_unbuffered() { + // A chain whose only entry is an unregistered binding has no resolvable + // body limit. Under fail_open the request must pass through with its + // body intact rather than being denied over a phantom zero-byte cap. + let (config, tunnel_engine, ctx) = + middleware_relay_context("third-party/missing", "fail_open"); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + let body = br#"{"api_key":"sk-1234567890abcdef"}"#; + let request = format!( + "POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + std::str::from_utf8(body).unwrap() + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut upstream_request = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("request should reach upstream") + .unwrap(); + let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); + // No middleware ran, so the body is forwarded verbatim. + assert!(upstream_request.contains(r#""api_key":"sk-1234567890abcdef""#)); + + upstream + .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + #[test] fn middleware_keeps_the_raw_request_query() { let query = raw_query_from_request_headers( @@ -3095,36 +3228,14 @@ network_policies: assert_eq!(input.scheme, "http"); } - /// Tracing layer that captures emitted `OcsfEvent`s for assertions. - struct OcsfCaptureLayer(Arc>>); - - impl tracing_subscriber::Layer for OcsfCaptureLayer { - fn on_event( - &self, - event: &tracing::Event<'_>, - _ctx: tracing_subscriber::layer::Context<'_, S>, - ) { - if event.metadata().target() == openshell_ocsf::OCSF_TARGET - && let Some(ocsf_event) = openshell_ocsf::clone_current_event() - { - self.0.lock().unwrap().push(ocsf_event); - } - } - } - #[test] fn middleware_ocsf_events_are_audit_safe() { use openshell_supervisor_middleware::{ ChainOutcome, MiddlewareInvocation, NamespacedFinding, }; - use tracing_subscriber::layer::SubscriberExt; const RAW_SECRET: &str = "sk-RAWSECRETVALUE0123456789"; - let events = Arc::new(std::sync::Mutex::new(Vec::new())); - let subscriber = tracing_subscriber::registry().with(OcsfCaptureLayer(Arc::clone(&events))); - let _guard = tracing::subscriber::set_default(subscriber); - let ctx = L7EvalContext { host: "api.example.test".into(), port: 443, @@ -3171,23 +3282,26 @@ network_policies: }], }; - emit_middleware_events(&ctx, &req, &outcome); + // Build the events directly rather than routing through the global + // tracing pipeline: its callsite-interest cache is process-global, so a + // parallel test that emits OCSF with no subscriber installed can cache + // the callsite as disabled and make captured-event assertions flaky. + let events = middleware_events(&ctx, &req, &outcome); - let captured = events.lock().unwrap(); // Per-invocation decisions are HTTP Activity (class 4002). assert!( - captured.iter().any(|e| e.class_uid() == 4002), + events.iter().any(|e| e.class_uid() == 4002), "expected an HTTP Activity event for the middleware invocation" ); // Findings are Detection Finding (class 2004) with the finding's severity. - let finding_event = captured + let finding_event = events .iter() .find(|e| e.class_uid() == 2004) .expect("expected a Detection Finding event"); assert_eq!(finding_event.base().severity, SeverityId::Medium); // No raw payload material may appear in any emitted event. - let serialized = serde_json::to_string(&*captured).expect("serialize events"); + let serialized = serde_json::to_string(&events).expect("serialize events"); assert!( !serialized.contains(RAW_SECRET), "raw secret leaked into OCSF events: {serialized}" @@ -3364,12 +3478,19 @@ network_policies: .await .unwrap(); - let mut response = [0u8; 512]; - let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) - .await - .expect("denial should reach client") - .unwrap(); - let response = String::from_utf8_lossy(&response[..n]); + // Accumulate until the reason marker arrives: the deny response can be + // delivered in more than one write, so a single read may return only the + // status line and flake the body assertion. + let mut response = Vec::new(); + let mut buf = [0u8; 512]; + while !String::from_utf8_lossy(&response).contains("middleware_failed") { + match tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut buf)).await { + Ok(Ok(0)) | Err(_) => break, // clean EOF, or no more data before the deadline + Ok(Ok(n)) => response.extend_from_slice(&buf[..n]), + Ok(Err(e)) => panic!("read from relay failed: {e}"), + } + } + let response = String::from_utf8_lossy(&response); assert!(response.contains("403 Forbidden")); assert!(response.contains("middleware_failed")); From 9d4b6315e628fad3ac49840706dbbb91d2cb6d8b Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Thu, 2 Jul 2026 13:16:45 -0700 Subject: [PATCH 17/27] fix(supervisor-middleware): harden policy enforcement Signed-off-by: Piotr Mlocek --- crates/openshell-sandbox/src/lib.rs | 7 +- crates/openshell-server/src/grpc/policy.rs | 32 ++++++ .../src/lib.rs | 98 ++++++++++++++++--- .../src/l7/relay.rs | 78 ++++++++++++++- .../src/l7/rest.rs | 29 +++++- 5 files changed, 228 insertions(+), 16 deletions(-) diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index 25750ee14..a0fa85937 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -1527,7 +1527,12 @@ async fn load_policy( let opa_engine = Some(Arc::new(engine)); let policy = SandboxPolicy::try_from(proto_policy.clone())?; - return Ok((policy, opa_engine, Some(proto_policy), middleware_install_pending)); + return Ok(( + policy, + opa_engine, + Some(proto_policy), + middleware_install_pending, + )); } // No policy source available diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 97d459813..b1e37f87e 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -3135,6 +3135,18 @@ fn deterministic_policy_hash(policy: &ProtoSandboxPolicy) -> String { hasher.update(key.as_bytes()); hasher.update(value.encode_to_vec()); } + if !policy.network_middlewares.is_empty() { + hasher.update(b"network_middlewares"); + for middleware in &policy.network_middlewares { + let encoded = middleware.encode_to_vec(); + hasher.update( + u64::try_from(encoded.len()) + .expect("protobuf payload length fits in u64") + .to_le_bytes(), + ); + hasher.update(encoded); + } + } hex::encode(hasher.finalize()) } @@ -9315,6 +9327,26 @@ mod tests { assert_ne!(rev_a, rev_b); } + #[test] + fn policy_hash_changes_when_network_middlewares_change() { + let policy_a = ProtoSandboxPolicy::default(); + let policy_b = ProtoSandboxPolicy { + network_middlewares: vec![openshell_core::proto::NetworkMiddlewareConfig { + name: "redact-secrets".into(), + middleware: "openshell/secrets".into(), + on_error: "fail_closed".into(), + ..Default::default() + }], + ..Default::default() + }; + + assert_ne!( + deterministic_policy_hash(&policy_a), + deterministic_policy_hash(&policy_b), + "middleware-only policy changes must produce a new policy hash" + ); + } + #[test] fn config_revision_changes_when_policy_source_changes() { let policy = ProtoSandboxPolicy::default(); diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index 27302dfe4..d2f8642b1 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -712,6 +712,37 @@ impl ChainRunner { } }; + if decision == Decision::Deny { + for finding in result.findings { + findings.push(NamespacedFinding { + middleware: entry.entry.name.clone(), + finding, + }); + } + if !result.metadata.is_empty() { + metadata.insert( + entry.entry.name.clone(), + result.metadata.into_iter().collect(), + ); + } + applied.push(MiddlewareInvocation { + name: entry.entry.name.clone(), + implementation: entry.entry.implementation.clone(), + decision, + transformed: false, + failed: false, + }); + return Ok(ChainOutcome { + allowed: false, + reason: safe_reason(&result.reason), + body, + added_headers, + findings, + metadata, + applied, + }); + } + if result.has_body && result.body.len() > entry.max_body_bytes { match apply_on_error(entry, "response_body_over_capacity", &mut applied) { OnErrorAction::FailOpen => continue, @@ -774,17 +805,6 @@ impl ChainRunner { transformed, failed: false, }); - if decision == Decision::Deny { - return Ok(ChainOutcome { - allowed: false, - reason: safe_reason(&result.reason), - body, - added_headers, - findings, - metadata, - applied, - }); - } } Ok(ChainOutcome { @@ -1399,6 +1419,62 @@ mod tests { assert!(!outcome.applied[0].failed); } + #[tokio::test] + async fn deny_decision_ignores_unsafe_mutations_under_fail_open() { + let runner = ChainRunner::new(Arc::new(scripted_service( + openshell_core::proto::HttpRequestResult { + decision: Decision::Deny as i32, + reason: "blocked_by_policy".into(), + add_headers: std::iter::once(( + "x-openshell-middleware-inject".to_string(), + "ok\r\nHost: evil".to_string(), + )) + .collect(), + ..allow_result() + }, + ))); + + let outcome = runner + .evaluate(&[entry("guard", OnError::FailOpen)], input("hello")) + .await + .expect("evaluate"); + + assert!(!outcome.allowed); + assert_eq!(outcome.reason, "blocked_by_policy"); + assert!(outcome.added_headers.is_empty()); + assert_eq!(outcome.applied.len(), 1); + assert_eq!(outcome.applied[0].decision, Decision::Deny); + assert!(!outcome.applied[0].failed); + } + + #[tokio::test] + async fn deny_decision_ignores_oversized_replacement_under_fail_open() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + binding_id: BUILTIN_SECRETS.into(), + max_body_bytes: 4, + result: openshell_core::proto::HttpRequestResult { + decision: Decision::Deny as i32, + reason: "blocked_by_policy".into(), + body: b"too large".to_vec(), + has_body: true, + ..allow_result() + }, + })); + + let outcome = runner + .evaluate(&[entry("guard", OnError::FailOpen)], input("safe")) + .await + .expect("evaluate"); + + assert!(!outcome.allowed); + assert_eq!(outcome.reason, "blocked_by_policy"); + assert_eq!(outcome.body, b"safe"); + assert_eq!(outcome.applied.len(), 1); + assert_eq!(outcome.applied[0].decision, Decision::Deny); + assert!(!outcome.applied[0].transformed); + assert!(!outcome.applied[0].failed); + } + #[tokio::test] async fn metadata_and_findings_are_namespaced_per_config() { let runner = ChainRunner::new(Arc::new(scripted_service( diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 3e92a3218..61daf3345 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -824,7 +824,14 @@ pub(crate) async fn apply_middleware_chain_for_scheme break, // clean EOF, or no more data before the deadline Ok(Ok(n)) => response.extend_from_slice(&buf[..n]), Ok(Err(e)) => panic!("read from relay failed: {e}"), diff --git a/crates/openshell-supervisor-network/src/l7/rest.rs b/crates/openshell-supervisor-network/src/l7/rest.rs index 15825d1b2..4f2d37f08 100644 --- a/crates/openshell-supervisor-network/src/l7/rest.rs +++ b/crates/openshell-supervisor-network/src/l7/rest.rs @@ -828,7 +828,7 @@ pub(crate) enum BufferResult { OverCapacity { recoverable: bool }, } -pub(crate) async fn buffer_request_body_for_middleware( +pub(crate) async fn buffer_request_body_for_middleware( req: &L7Request, client: &mut C, generation_guard: Option<&PolicyGenerationGuard>, @@ -839,7 +839,7 @@ pub(crate) async fn buffer_request_body_for_middleware( .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(req.raw_header.len(), |p| p + 4); - let headers = req.raw_header[..header_end].to_vec(); + let mut headers = req.raw_header[..header_end].to_vec(); let already_read = &req.raw_header[header_end..]; match req.body_length { BodyLength::None => Ok(BufferResult::Buffered(BufferedRequestBody { @@ -860,6 +860,9 @@ pub(crate) async fn buffer_request_body_for_middleware( let mut body = Vec::with_capacity(len); body.extend_from_slice(&already_read[..initial_len]); let mut remaining = len.saturating_sub(initial_len); + if remaining > 0 && already_read.is_empty() { + acknowledge_expect_continue(client, &mut headers).await?; + } let mut buf = [0u8; RELAY_BUF_SIZE]; while remaining > 0 { let to_read = remaining.min(buf.len()); @@ -887,6 +890,9 @@ pub(crate) async fn buffer_request_body_for_middleware( // we have already consumed wire bytes from the client stream and // cannot re-enter the normal raw relay path without a separate // splice-through buffer. + if already_read.is_empty() { + acknowledge_expect_continue(client, &mut headers).await?; + } Ok( collect_chunked_body(client, already_read, generation_guard, Some(max_body_bytes)) .await @@ -898,6 +904,25 @@ pub(crate) async fn buffer_request_body_for_middleware( } } +async fn acknowledge_expect_continue( + client: &mut C, + headers: &mut Vec, +) -> Result<()> { + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + if !has_expect_continue(header_str) { + return Ok(()); + } + + client + .write_all(b"HTTP/1.1 100 Continue\r\n\r\n") + .await + .into_diagnostic()?; + client.flush().await.into_diagnostic()?; + *headers = strip_header(headers, "expect")?; + Ok(()) +} + pub(crate) fn rebuild_request_with_buffered_body( req: &L7Request, headers: &[u8], From 2b7cf4e1f9d4d70efa614998d8034637b9c8c90a Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Thu, 2 Jul 2026 14:08:08 -0700 Subject: [PATCH 18/27] feat(ocsf): enrich middleware shorthand logs Signed-off-by: Piotr Mlocek --- crates/openshell-ocsf/src/format/shorthand.rs | 129 ++++++++++++++++-- docs/observability/logging.mdx | 9 +- 2 files changed, 125 insertions(+), 13 deletions(-) diff --git a/crates/openshell-ocsf/src/format/shorthand.rs b/crates/openshell-ocsf/src/format/shorthand.rs index 3143b6d68..8765f1e06 100644 --- a/crates/openshell-ocsf/src/format/shorthand.rs +++ b/crates/openshell-ocsf/src/format/shorthand.rs @@ -7,7 +7,22 @@ use crate::events::OcsfEvent; use crate::events::base_event::BaseEventData; -use crate::objects::Url; +use crate::objects::{Evidence, Url}; + +fn finding_evidence_value<'a>(evidences: Option<&'a [Evidence]>, key: &str) -> Option<&'a str> { + evidences? + .iter() + .filter_map(|evidence| evidence.data.as_ref()?.as_object()) + .find_map(|data| data.get(key)?.as_str()) +} + +fn message_bool_value<'a>(message: Option<&'a str>, key: &str) -> Option<&'a str> { + let prefix = format!("{key}="); + message? + .split_ascii_whitespace() + .find_map(|field| field.strip_prefix(&prefix)) + .filter(|value| matches!(*value, "true" | "false")) +} /// Format a timestamp (ms since epoch) as `HH:MM:SS.mmm`. /// @@ -195,10 +210,32 @@ impl OcsfEvent { .and_then(|r| r.url.as_ref()) .map(Url::to_display_string) .unwrap_or_default(); + let transformed = e + .firewall_rule + .as_ref() + .filter(|rule| rule.rule_type == "middleware") + .and_then(|_| message_bool_value(e.base.message.as_deref(), "transformed")); + let failed = e + .firewall_rule + .as_ref() + .filter(|rule| rule.rule_type == "middleware") + .and_then(|_| message_bool_value(e.base.message.as_deref(), "failed")); let rule_ctx = e .firewall_rule .as_ref() - .map(|r| format!(" [policy:{} engine:{}]", r.name, r.rule_type)) + .map(|r| { + let mut context = vec![ + format!("policy:{}", r.name), + format!("engine:{}", r.rule_type), + ]; + if let Some(value) = transformed { + context.push(format!("transformed:{value}")); + } + if let Some(value) = failed { + context.push(format!("failed:{value}")); + } + format!(" [{}]", context.join(" ")) + }) .unwrap_or_default(); // For denied events, surface the reason from status_detail let reason_ctx = if action == "DENIED" { @@ -280,16 +317,26 @@ impl OcsfEvent { } Self::DetectionFinding(e) => { - let disposition = e - .disposition - .map_or_else(|| "UNKNOWN".to_string(), |d| d.label().to_uppercase()); + let disposition = e.disposition.map_or_else( + || e.base.activity_name.to_uppercase(), + |d| d.label().to_uppercase(), + ); let title = &e.finding_info.title; - let confidence_ctx = e - .confidence - .map(|c| format!(" [confidence:{}]", c.label().to_lowercase())) - .unwrap_or_default(); - - format!("FINDING:{disposition} {sev} \"{title}\"{confidence_ctx}") + let mut context = vec![format!("type:{}", e.finding_info.uid)]; + if let Some(middleware) = + finding_evidence_value(e.evidences.as_deref(), "middleware") + { + context.push(format!("middleware:{middleware}")); + } + if let Some(count) = finding_evidence_value(e.evidences.as_deref(), "count") { + context.push(format!("count:{count}")); + } + if let Some(confidence) = e.confidence { + context.push(format!("confidence:{}", confidence.label().to_lowercase())); + } + let context = format!(" [{}]", context.join(" ")); + + format!("FINDING:{disposition} {sev} \"{title}\"{context}") } Self::ApplicationLifecycle(e) => { @@ -534,6 +581,37 @@ mod tests { ); } + #[test] + fn test_http_activity_shorthand_includes_middleware_outcome() { + let mut base = base(4002, "HTTP Activity", 4, "Network Activity", 99, "Other"); + base.set_message( + "MIDDLEWARE prototype-content-guard example/content-guard decision=Allow transformed=false failed=true", + ); + let event = OcsfEvent::HttpActivity(HttpActivityEvent { + base, + http_request: Some(HttpRequest::new( + "POST", + Url::new("http", "httpbin.org", "/anything", 443), + )), + http_response: None, + src_endpoint: None, + dst_endpoint: None, + proxy_endpoint: None, + actor: None, + firewall_rule: Some(FirewallRule::new("httpbin", "middleware")), + action: Some(ActionId::Allowed), + disposition: Some(DispositionId::Allowed), + observation_point_id: None, + is_src_dst_assignment_known: None, + }); + + let shorthand = event.format_shorthand(); + assert_eq!( + shorthand, + "HTTP:POST [INFO] ALLOWED POST http://httpbin.org:443/anything [policy:httpbin engine:middleware transformed:false failed:true]" + ); + } + #[test] fn test_network_activity_shorthand_denied_shows_reason() { let mut b = base(4001, "Network Activity", 4, "Network Activity", 1, "Open"); @@ -863,8 +941,35 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "FINDING:BLOCKED [HIGH] \"NSSH1 Nonce Replay Attack\" [confidence:high]" + "FINDING:BLOCKED [HIGH] \"NSSH1 Nonce Replay Attack\" [type:nssh1-replay-abc confidence:high]" + ); + } + + #[test] + fn test_detection_finding_shorthand_uses_activity_and_safe_evidence() { + let event = OcsfEvent::DetectionFinding(DetectionFindingEvent { + base: base(2004, "Detection Finding", 2, "Findings", 1, "Create"), + finding_info: FindingInfo::new("content_guard.match", "configured content matched"), + evidences: Some(vec![Evidence::from_pairs(&[ + ("middleware", "prototype-content-guard"), + ("count", "1"), + ("matched_content", "must-not-appear"), + ])]), + attacks: None, + remediation: None, + is_alert: None, + confidence: None, + risk_level: None, + action: None, + disposition: None, + }); + + let shorthand = event.format_shorthand(); + assert_eq!( + shorthand, + "FINDING:CREATE [INFO] \"configured content matched\" [type:content_guard.match middleware:prototype-content-guard count:1]" ); + assert!(!shorthand.contains("must-not-appear")); } #[test] diff --git a/docs/observability/logging.mdx b/docs/observability/logging.mdx index dcfe9f19d..206618107 100644 --- a/docs/observability/logging.mdx +++ b/docs/observability/logging.mdx @@ -94,7 +94,7 @@ CLASS:ACTIVITY [SEVERITY] ACTION DETAILS [CONTEXT] - SSH: peer address and authentication type - Process: `name(pid)` with exit code or command line - Config: description of what changed -- Finding: quoted title with confidence level +- Finding: quoted title with the stable finding type, optional confidence, and allowlisted evidence fields when available **Context** in brackets at the end provides the policy rule and enforcement engine that produced the decision. @@ -130,6 +130,13 @@ An HTTP request to a non-default port. HTTP log URLs include the port whenever i OCSF HTTP:GET [INFO] ALLOWED GET http://api.internal.corp:8080/v1/status [policy:internal_api engine:opa] ``` +A supervisor middleware HTTP event records whether it transformed the request. If the middleware also emits a finding, that remains a separate event: + +```text +OCSF HTTP:POST [INFO] ALLOWED POST http://httpbin.org:443/anything [policy:httpbin engine:middleware transformed:true failed:false] +OCSF FINDING:CREATE [MED] "configured content matched" [type:content_guard.match middleware:prototype-content-guard count:1] +``` + Proxy and SSH servers ready: ```text From 0ef948bff0b16405ecb7e6b44d43520ab40df519 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Thu, 2 Jul 2026 15:34:52 -0700 Subject: [PATCH 19/27] fix(policy): initialize network middleware test fixtures Signed-off-by: Piotr Mlocek --- crates/openshell-policy/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 29585a623..efcca2fca 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -2678,6 +2678,7 @@ network_policies: filesystem: None, landlock: None, network_policies: HashMap::new(), + network_middlewares: Vec::new(), }; assert!(validate_sandbox_policy(&policy).is_ok()); } @@ -2693,6 +2694,7 @@ network_policies: filesystem: None, landlock: None, network_policies: HashMap::new(), + network_middlewares: Vec::new(), }; assert!(validate_sandbox_policy(&policy).is_ok()); } @@ -2771,6 +2773,7 @@ network_policies: filesystem: None, landlock: None, network_policies: HashMap::new(), + network_middlewares: Vec::new(), }; assert!(validate_sandbox_policy(&policy).is_ok()); } From 400ca0fa6c87a2ac8580aa1826f37ebca7c42cb6 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Thu, 2 Jul 2026 16:39:23 -0700 Subject: [PATCH 20/27] fix(policy): preserve immutable validation precedence Signed-off-by: Piotr Mlocek --- crates/openshell-server/src/grpc/policy.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index b1e37f87e..21d7e0254 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -1848,12 +1848,14 @@ async fn handle_update_config_inner( validate_no_reserved_provider_policy_keys(&new_policy)?; } + if let Some(baseline_policy) = spec.policy.as_ref() { + validate_static_fields_unchanged(baseline_policy, &new_policy)?; + } + validate_policy_safety(&new_policy)?; crate::middleware::validate_policy(state.middleware_registry.as_ref(), &new_policy).await?; - if let Some(baseline_policy) = spec.policy.as_ref() { - validate_static_fields_unchanged(baseline_policy, &new_policy)?; - } else { + if spec.policy.is_none() { // Backfill spec.policy using CAS (first-time policy discovery) let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; let sandbox_id = sandbox.object_id().to_string(); From 1f6aec16e84bc6b7ff9777ae7f87483d7bb38624 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Thu, 2 Jul 2026 18:00:20 -0700 Subject: [PATCH 21/27] feat(ocsf): log middleware denial reasons Signed-off-by: Piotr Mlocek --- crates/openshell-ocsf/src/format/shorthand.rs | 76 +++++++++---------- .../src/l7/relay.rs | 42 +++++++++- docs/extensibility/supervisor-middleware.mdx | 1 + docs/observability/logging.mdx | 4 +- proto/middleware.proto | 4 +- 5 files changed, 84 insertions(+), 43 deletions(-) diff --git a/crates/openshell-ocsf/src/format/shorthand.rs b/crates/openshell-ocsf/src/format/shorthand.rs index 8765f1e06..22bd1759b 100644 --- a/crates/openshell-ocsf/src/format/shorthand.rs +++ b/crates/openshell-ocsf/src/format/shorthand.rs @@ -97,22 +97,20 @@ fn truncate_with_ellipsis(text: &str, max: usize) -> String { format!("{}...", &text[..end]) } -/// Format a `[reason:...]` tag from `status_detail` (or `message` fallback) -/// for denied events. Returns an empty string if neither field is set. -fn reason_tag(base: &BaseEventData) -> String { - let text = base - .status_detail - .as_deref() - .or(base.message.as_deref()) - .unwrap_or(""); +fn reason_text(text: Option<&str>) -> Option { + let text = text?; if text.is_empty() { - return String::new(); + return None; } let text = text.replace(['\n', '\r'], " "); - format!( - " [reason:{}]", - truncate_with_ellipsis(&text, MAX_REASON_LEN) - ) + Some(truncate_with_ellipsis(&text, MAX_REASON_LEN)) +} + +/// Format a `[reason:...]` tag from `status_detail` (or `message` fallback) +/// for denied events. Returns an empty string if neither field is set. +fn reason_tag(base: &BaseEventData) -> String { + reason_text(base.status_detail.as_deref().or(base.message.as_deref())) + .map_or_else(String::new, |text| format!(" [reason:{text}]")) } fn message_tag(base: &BaseEventData) -> String { @@ -210,35 +208,37 @@ impl OcsfEvent { .and_then(|r| r.url.as_ref()) .map(Url::to_display_string) .unwrap_or_default(); - let transformed = e + let is_middleware = e .firewall_rule .as_ref() - .filter(|rule| rule.rule_type == "middleware") - .and_then(|_| message_bool_value(e.base.message.as_deref(), "transformed")); - let failed = e - .firewall_rule - .as_ref() - .filter(|rule| rule.rule_type == "middleware") - .and_then(|_| message_bool_value(e.base.message.as_deref(), "failed")); + .is_some_and(|rule| rule.rule_type == "middleware"); let rule_ctx = e .firewall_rule .as_ref() - .map(|r| { - let mut context = vec![ - format!("policy:{}", r.name), - format!("engine:{}", r.rule_type), - ]; - if let Some(value) = transformed { - context.push(format!("transformed:{value}")); - } - if let Some(value) = failed { - context.push(format!("failed:{value}")); - } - format!(" [{}]", context.join(" ")) - }) + .map(|r| format!(" [policy:{} engine:{}]", r.name, r.rule_type)) .unwrap_or_default(); - // For denied events, surface the reason from status_detail - let reason_ctx = if action == "DENIED" { + let outcome_ctx = if is_middleware { + let mut context = Vec::new(); + if let Some(value) = + message_bool_value(e.base.message.as_deref(), "transformed") + { + context.push(format!("transformed:{value}")); + } + if let Some(value) = message_bool_value(e.base.message.as_deref(), "failed") { + context.push(format!("failed:{value}")); + } + if action == "DENIED" + && let Some(reason) = reason_text(e.base.status_detail.as_deref()) + { + // Keep the free-form reason last so the preceding fields remain easy to parse. + context.push(format!("reason:{reason}")); + } + if context.is_empty() { + String::new() + } else { + format!(" [{}]", context.join(" ")) + } + } else if action == "DENIED" { reason_tag(&e.base) } else { String::new() @@ -255,7 +255,7 @@ impl OcsfEvent { (false, true) => format!(" {action}"), (false, false) => format!(" {action}{arrow}"), }; - format!("HTTP:{method} {sev}{detail}{rule_ctx}{reason_ctx}") + format!("HTTP:{method} {sev}{detail}{rule_ctx}{outcome_ctx}") } Self::SshActivity(e) => { @@ -608,7 +608,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "HTTP:POST [INFO] ALLOWED POST http://httpbin.org:443/anything [policy:httpbin engine:middleware transformed:false failed:true]" + "HTTP:POST [INFO] ALLOWED POST http://httpbin.org:443/anything [policy:httpbin engine:middleware] [transformed:false failed:true]" ); } diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 61daf3345..29bbece56 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -1000,7 +1000,7 @@ fn middleware_events( let mut events = Vec::new(); for invocation in &outcome.applied { let allowed = invocation.decision == openshell_core::proto::Decision::Allow; - let event = HttpActivityBuilder::new(openshell_ocsf::ctx::ctx()) + let mut event = HttpActivityBuilder::new(openshell_ocsf::ctx::ctx()) .activity(ActivityId::Other) .action(if allowed { ActionId::Allowed @@ -1030,8 +1030,13 @@ fn middleware_events( invocation.decision, invocation.transformed, invocation.failed - )) - .build(); + )); + if !allowed && !outcome.reason.is_empty() { + event = event + .status(StatusId::Failure) + .status_detail(&outcome.reason); + } + let event = event.build(); events.push(event); // A middleware that failed but was bypassed under `fail_open` is an @@ -3381,6 +3386,37 @@ network_policies: ); // Safe finding metadata is still present. assert!(serialized.contains("secret.common")); + + let denied_outcome = ChainOutcome { + allowed: false, + reason: "request matched configured policy".into(), + body: Vec::new(), + added_headers: BTreeMap::new(), + findings: Vec::new(), + metadata: BTreeMap::new(), + applied: vec![MiddlewareInvocation { + name: "content-guard".into(), + implementation: "example/content-guard".into(), + decision: openshell_core::proto::Decision::Deny, + transformed: false, + failed: false, + }], + }; + let denied_events = middleware_events(&ctx, &req, &denied_outcome); + let denied_http = denied_events + .iter() + .find(|event| event.class_uid() == 4002) + .expect("expected denied HTTP Activity event"); + assert_eq!( + denied_http.base().status_detail.as_deref(), + Some("request matched configured policy") + ); + assert_eq!( + denied_http.format_shorthand(), + "HTTP:POST [MED] DENIED POST http://api.example.test:443/v1/messages \ + [policy:rest_api engine:middleware] \ + [transformed:false failed:false reason:request matched configured policy]" + ); } #[tokio::test] diff --git a/docs/extensibility/supervisor-middleware.mdx b/docs/extensibility/supervisor-middleware.mdx index 0ebd518bd..f2ea0c98a 100644 --- a/docs/extensibility/supervisor-middleware.mdx +++ b/docs/extensibility/supervisor-middleware.mdx @@ -120,6 +120,7 @@ When the effective sandbox configuration changes, a running supervisor validates Middleware activity is emitted through OpenShell's OCSF logging: - Each invocation records its policy-local middleware name, binding, decision, transformation state, and failure state. +- A denied invocation records the service-provided audit-safe reason without request content, configured terms, credentials, or other secrets. - A bypass under `fail_open` emits a detection finding. - A required stage that fails closed emits a high-severity detection finding. - Findings include the service-provided type and label plus aggregate counts. Middleware services should keep those fields audit-safe and omit request content or matched values. diff --git a/docs/observability/logging.mdx b/docs/observability/logging.mdx index 206618107..31281de88 100644 --- a/docs/observability/logging.mdx +++ b/docs/observability/logging.mdx @@ -133,7 +133,7 @@ OCSF HTTP:GET [INFO] ALLOWED GET http://api.internal.corp:8080/v1/status [policy A supervisor middleware HTTP event records whether it transformed the request. If the middleware also emits a finding, that remains a separate event: ```text -OCSF HTTP:POST [INFO] ALLOWED POST http://httpbin.org:443/anything [policy:httpbin engine:middleware transformed:true failed:false] +OCSF HTTP:POST [INFO] ALLOWED POST http://httpbin.org:443/anything [policy:httpbin engine:middleware] [transformed:true failed:false] OCSF FINDING:CREATE [MED] "configured content matched" [type:content_guard.match middleware:prototype-content-guard count:1] ``` @@ -167,6 +167,8 @@ OCSF CONFIG:LOADED [INFO] Policy reloaded successfully [policy_hash:0cc0c2b52557 Denied `NET:` and `HTTP:` events carry a `[reason:...]` suffix that surfaces the decision detail from the event's `status_detail` field. The reason helps distinguish between policy misses, SSRF hardening, and L7 enforcement without inspecting the full OCSF JSONL record. +For supervisor middleware denials, `status_detail` contains the service-provided audit-safe reason from the gRPC response. + Common reason phrases emitted by the sandbox include: | Reason | Meaning | diff --git a/proto/middleware.proto b/proto/middleware.proto index 9b988b930..7a4eb28df 100644 --- a/proto/middleware.proto +++ b/proto/middleware.proto @@ -148,7 +148,9 @@ message Finding { message HttpRequestResult { // Allow or deny decision for this request. Decision decision = 1; - // Human-readable reason used for diagnostics and denied responses. + // Audit-safe human-readable reason used for diagnostics, denied responses, + // and security logs. Must not contain request content, configured terms, + // credentials, or other secrets. string reason = 2; // Replacement request body when has_body is true. bytes body = 3; From 6fdc5e3e736d068231278b2a081aceb8860881e8 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 3 Jul 2026 11:43:26 -0700 Subject: [PATCH 22/27] refactor(ocsf): decouple shorthand from middleware Signed-off-by: Piotr Mlocek --- crates/openshell-ocsf/src/builders/finding.rs | 14 ++ crates/openshell-ocsf/src/builders/http.rs | 16 +++ crates/openshell-ocsf/src/format/shorthand.rs | 123 ++++++++---------- .../src/l7/relay.rs | 19 ++- docs/observability/logging.mdx | 8 +- 5 files changed, 101 insertions(+), 79 deletions(-) diff --git a/crates/openshell-ocsf/src/builders/finding.rs b/crates/openshell-ocsf/src/builders/finding.rs index 177c9159d..fecb1e927 100644 --- a/crates/openshell-ocsf/src/builders/finding.rs +++ b/crates/openshell-ocsf/src/builders/finding.rs @@ -25,6 +25,7 @@ pub struct DetectionFindingBuilder<'a> { risk_level: Option, message: Option, log_source: Option, + unmapped: serde_json::Map, } impl<'a> DetectionFindingBuilder<'a> { @@ -45,6 +46,7 @@ impl<'a> DetectionFindingBuilder<'a> { risk_level: None, message: None, log_source: None, + unmapped: serde_json::Map::new(), } } @@ -74,6 +76,13 @@ impl<'a> DetectionFindingBuilder<'a> { self } + /// Add a source-specific attribute that is not defined by the OCSF class. + #[must_use] + pub fn unmapped(mut self, key: &str, value: impl Into) -> Self { + self.unmapped.insert(key.to_string(), value.into()); + self + } + /// Add a remediation description. #[must_use] pub fn remediation(mut self, desc: &str) -> Self { @@ -122,6 +131,9 @@ impl<'a> DetectionFindingBuilder<'a> { self.severity, metadata, ); + if !self.unmapped.is_empty() { + base.unmapped = Some(serde_json::Value::Object(self.unmapped)); + } self.ctx.apply_common_fields(&mut base, None, self.message); OcsfEvent::DetectionFinding(DetectionFindingEvent { @@ -169,6 +181,7 @@ mod tests { .is_alert(true) .confidence(ConfidenceId::High) .risk_level(RiskLevelId::High) + .unmapped("source_rule", "nonce_replay") .finding_info( FindingInfo::new("nssh1-replay-abc", "NSSH1 Nonce Replay Attack") .with_desc("A nonce was replayed."), @@ -188,5 +201,6 @@ mod tests { assert_eq!(json["finding_info"]["title"], "NSSH1 Nonce Replay Attack"); assert_eq!(json["is_alert"], true); assert_eq!(json["confidence"], "High"); + assert_eq!(json["unmapped"]["source_rule"], "nonce_replay"); } } diff --git a/crates/openshell-ocsf/src/builders/http.rs b/crates/openshell-ocsf/src/builders/http.rs index 4fa33eb1c..ae91cf7c4 100644 --- a/crates/openshell-ocsf/src/builders/http.rs +++ b/crates/openshell-ocsf/src/builders/http.rs @@ -25,6 +25,7 @@ pub struct HttpActivityBuilder<'a> { firewall_rule: Option, message: Option, status_detail: Option, + unmapped: serde_json::Map, } impl<'a> HttpActivityBuilder<'a> { @@ -45,6 +46,7 @@ impl<'a> HttpActivityBuilder<'a> { firewall_rule: None, message: None, status_detail: None, + unmapped: serde_json::Map::new(), } } @@ -69,6 +71,13 @@ impl<'a> HttpActivityBuilder<'a> { self } + /// Add a source-specific attribute that is not defined by the OCSF class. + #[must_use] + pub fn unmapped(mut self, key: &str, value: impl Into) -> Self { + self.unmapped.insert(key.to_string(), value.into()); + self + } + #[must_use] pub fn build(self) -> OcsfEvent { let activity_name = self.activity.http_label().to_string(); @@ -86,6 +95,9 @@ impl<'a> HttpActivityBuilder<'a> { if let Some(detail) = self.status_detail { base.set_status_detail(detail); } + if !self.unmapped.is_empty() { + base.unmapped = Some(serde_json::Value::Object(self.unmapped)); + } self.ctx .apply_common_fields(&mut base, self.status, self.message); @@ -157,11 +169,15 @@ mod tests { .firewall_rule("aws_iam", "ssrf") .message("FORWARD blocked: allowed_ips check failed") .status_detail("resolves to always-blocked address") + .unmapped("attempt", 2) + .unmapped("cached", true) .build(); let json = event.to_json().unwrap(); assert_eq!(json["class_uid"], 4002); assert_eq!(json["status_detail"], "resolves to always-blocked address"); + assert_eq!(json["unmapped"]["attempt"], 2); + assert_eq!(json["unmapped"]["cached"], true); assert_eq!(json["action_id"], 2); // Denied } } diff --git a/crates/openshell-ocsf/src/format/shorthand.rs b/crates/openshell-ocsf/src/format/shorthand.rs index 22bd1759b..5e5f93244 100644 --- a/crates/openshell-ocsf/src/format/shorthand.rs +++ b/crates/openshell-ocsf/src/format/shorthand.rs @@ -7,22 +7,7 @@ use crate::events::OcsfEvent; use crate::events::base_event::BaseEventData; -use crate::objects::{Evidence, Url}; - -fn finding_evidence_value<'a>(evidences: Option<&'a [Evidence]>, key: &str) -> Option<&'a str> { - evidences? - .iter() - .filter_map(|evidence| evidence.data.as_ref()?.as_object()) - .find_map(|data| data.get(key)?.as_str()) -} - -fn message_bool_value<'a>(message: Option<&'a str>, key: &str) -> Option<&'a str> { - let prefix = format!("{key}="); - message? - .split_ascii_whitespace() - .find_map(|field| field.strip_prefix(&prefix)) - .filter(|value| matches!(*value, "true" | "false")) -} +use crate::objects::Url; /// Format a timestamp (ms since epoch) as `HH:MM:SS.mmm`. /// @@ -113,6 +98,43 @@ fn reason_tag(base: &BaseEventData) -> String { .map_or_else(String::new, |text| format!(" [reason:{text}]")) } +fn unmapped_fields(base: &BaseEventData) -> Vec { + base.unmapped + .as_ref() + .and_then(serde_json::Value::as_object) + .into_iter() + .flatten() + .filter_map(|(key, value)| { + let value = match value { + serde_json::Value::Bool(value) => value.to_string(), + serde_json::Value::Number(value) => value.to_string(), + serde_json::Value::String(value) => { + let value = value.replace(['\n', '\r'], " "); + truncate_with_ellipsis(&value, MAX_REASON_LEN) + } + _ => return None, + }; + Some(format!("{key}:{value}")) + }) + .collect() +} + +fn unmapped_context(base: &BaseEventData, include_reason: bool) -> String { + let mut fields = unmapped_fields(base); + + if include_reason + && let Some(reason) = reason_text(base.status_detail.as_deref().or(base.message.as_deref())) + { + fields.push(format!("reason:{reason}")); + } + + if fields.is_empty() { + String::new() + } else { + format!(" [{}]", fields.join(" ")) + } +} + fn message_tag(base: &BaseEventData) -> String { let text = base.message.as_deref().unwrap_or(""); if text.is_empty() { @@ -208,41 +230,12 @@ impl OcsfEvent { .and_then(|r| r.url.as_ref()) .map(Url::to_display_string) .unwrap_or_default(); - let is_middleware = e - .firewall_rule - .as_ref() - .is_some_and(|rule| rule.rule_type == "middleware"); let rule_ctx = e .firewall_rule .as_ref() .map(|r| format!(" [policy:{} engine:{}]", r.name, r.rule_type)) .unwrap_or_default(); - let outcome_ctx = if is_middleware { - let mut context = Vec::new(); - if let Some(value) = - message_bool_value(e.base.message.as_deref(), "transformed") - { - context.push(format!("transformed:{value}")); - } - if let Some(value) = message_bool_value(e.base.message.as_deref(), "failed") { - context.push(format!("failed:{value}")); - } - if action == "DENIED" - && let Some(reason) = reason_text(e.base.status_detail.as_deref()) - { - // Keep the free-form reason last so the preceding fields remain easy to parse. - context.push(format!("reason:{reason}")); - } - if context.is_empty() { - String::new() - } else { - format!(" [{}]", context.join(" ")) - } - } else if action == "DENIED" { - reason_tag(&e.base) - } else { - String::new() - }; + let outcome_ctx = unmapped_context(&e.base, action == "DENIED"); let arrow = if actor_str.is_empty() { format!(" {method} {url_str}") } else { @@ -323,14 +316,7 @@ impl OcsfEvent { ); let title = &e.finding_info.title; let mut context = vec![format!("type:{}", e.finding_info.uid)]; - if let Some(middleware) = - finding_evidence_value(e.evidences.as_deref(), "middleware") - { - context.push(format!("middleware:{middleware}")); - } - if let Some(count) = finding_evidence_value(e.evidences.as_deref(), "count") { - context.push(format!("count:{count}")); - } + context.extend(unmapped_fields(&e.base)); if let Some(confidence) = e.confidence { context.push(format!("confidence:{}", confidence.label().to_lowercase())); } @@ -582,11 +568,10 @@ mod tests { } #[test] - fn test_http_activity_shorthand_includes_middleware_outcome() { + fn test_http_activity_shorthand_includes_unmapped_attributes() { let mut base = base(4002, "HTTP Activity", 4, "Network Activity", 99, "Other"); - base.set_message( - "MIDDLEWARE prototype-content-guard example/content-guard decision=Allow transformed=false failed=true", - ); + base.add_unmapped("attempt", serde_json::json!(2)); + base.add_unmapped("cached", serde_json::json!(true)); let event = OcsfEvent::HttpActivity(HttpActivityEvent { base, http_request: Some(HttpRequest::new( @@ -598,7 +583,7 @@ mod tests { dst_endpoint: None, proxy_endpoint: None, actor: None, - firewall_rule: Some(FirewallRule::new("httpbin", "middleware")), + firewall_rule: Some(FirewallRule::new("httpbin", "extension")), action: Some(ActionId::Allowed), disposition: Some(DispositionId::Allowed), observation_point_id: None, @@ -608,7 +593,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "HTTP:POST [INFO] ALLOWED POST http://httpbin.org:443/anything [policy:httpbin engine:middleware] [transformed:false failed:true]" + "HTTP:POST [INFO] ALLOWED POST http://httpbin.org:443/anything [policy:httpbin engine:extension] [attempt:2 cached:true]" ); } @@ -946,15 +931,17 @@ mod tests { } #[test] - fn test_detection_finding_shorthand_uses_activity_and_safe_evidence() { + fn test_detection_finding_shorthand_uses_activity_and_safe_unmapped_attributes() { + let mut base = base(2004, "Detection Finding", 2, "Findings", 1, "Create"); + base.add_unmapped("count", serde_json::json!(1)); + base.add_unmapped("source", serde_json::json!("content_guard")); let event = OcsfEvent::DetectionFinding(DetectionFindingEvent { - base: base(2004, "Detection Finding", 2, "Findings", 1, "Create"), + base, finding_info: FindingInfo::new("content_guard.match", "configured content matched"), - evidences: Some(vec![Evidence::from_pairs(&[ - ("middleware", "prototype-content-guard"), - ("count", "1"), - ("matched_content", "must-not-appear"), - ])]), + evidences: Some(vec![Evidence::from_pairs(&[( + "matched_content", + "must-not-appear", + )])]), attacks: None, remediation: None, is_alert: None, @@ -967,7 +954,7 @@ mod tests { let shorthand = event.format_shorthand(); assert_eq!( shorthand, - "FINDING:CREATE [INFO] \"configured content matched\" [type:content_guard.match middleware:prototype-content-guard count:1]" + "FINDING:CREATE [INFO] \"configured content matched\" [type:content_guard.match count:1 source:content_guard]" ); assert!(!shorthand.contains("must-not-appear")); } diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 29bbece56..329d4b42d 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -1023,13 +1023,11 @@ fn middleware_events( )) .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) .firewall_rule(&ctx.policy_name, "middleware") + .unmapped("transformed", invocation.transformed) + .unmapped("failed", invocation.failed) .message(format!( - "MIDDLEWARE {} {} decision={:?} transformed={} failed={}", - invocation.name, - invocation.implementation, - invocation.decision, - invocation.transformed, - invocation.failed + "MIDDLEWARE {} {} decision={:?}", + invocation.name, invocation.implementation, invocation.decision )); if !allowed && !outcome.reason.is_empty() { event = event @@ -1053,6 +1051,8 @@ fn middleware_events( ("middleware", invocation.name.as_str()), ("implementation", invocation.implementation.as_str()), ]) + .unmapped("middleware", invocation.name.as_str()) + .unmapped("implementation", invocation.implementation.as_str()) .message(format!( "Middleware {} failed and was bypassed (fail_open)", invocation.name @@ -1087,6 +1087,8 @@ fn middleware_events( ("middleware", &finding.middleware), ("count", &finding.finding.count.to_string()), ]) + .unmapped("middleware", finding.middleware.as_str()) + .unmapped("count", finding.finding.count) .message(format!( "Middleware finding {} count={}", finding.finding.r#type, finding.finding.count @@ -3411,11 +3413,14 @@ network_policies: denied_http.base().status_detail.as_deref(), Some("request matched configured policy") ); + let denied_json = denied_http.to_json().expect("serialize denied event"); + assert_eq!(denied_json["unmapped"]["transformed"], false); + assert_eq!(denied_json["unmapped"]["failed"], false); assert_eq!( denied_http.format_shorthand(), "HTTP:POST [MED] DENIED POST http://api.example.test:443/v1/messages \ [policy:rest_api engine:middleware] \ - [transformed:false failed:false reason:request matched configured policy]" + [failed:false transformed:false reason:request matched configured policy]" ); } diff --git a/docs/observability/logging.mdx b/docs/observability/logging.mdx index 31281de88..6cd24c70e 100644 --- a/docs/observability/logging.mdx +++ b/docs/observability/logging.mdx @@ -94,9 +94,9 @@ CLASS:ACTIVITY [SEVERITY] ACTION DETAILS [CONTEXT] - SSH: peer address and authentication type - Process: `name(pid)` with exit code or command line - Config: description of what changed -- Finding: quoted title with the stable finding type, optional confidence, and allowlisted evidence fields when available +- Finding: quoted title with the stable finding type, optional confidence, and source-specific context attributes when available -**Context** in brackets at the end provides the policy rule and enforcement engine that produced the decision. +**Context** in brackets provides structured fields such as policy provenance, source-specific attributes, and denial reasons. ### Examples @@ -133,8 +133,8 @@ OCSF HTTP:GET [INFO] ALLOWED GET http://api.internal.corp:8080/v1/status [policy A supervisor middleware HTTP event records whether it transformed the request. If the middleware also emits a finding, that remains a separate event: ```text -OCSF HTTP:POST [INFO] ALLOWED POST http://httpbin.org:443/anything [policy:httpbin engine:middleware] [transformed:true failed:false] -OCSF FINDING:CREATE [MED] "configured content matched" [type:content_guard.match middleware:prototype-content-guard count:1] +OCSF HTTP:POST [INFO] ALLOWED POST http://httpbin.org:443/anything [policy:httpbin engine:middleware] [failed:false transformed:true] +OCSF FINDING:CREATE [MED] "configured content matched" [type:content_guard.match count:1 middleware:prototype-content-guard] ``` Proxy and SSH servers ready: From d27085cb73566d9588cc3f7f202090aa25af926b Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 3 Jul 2026 12:08:51 -0700 Subject: [PATCH 23/27] refactor(supervisor-middleware): clarify runtime integration Signed-off-by: Piotr Mlocek --- crates/openshell-sandbox/src/lib.rs | 176 +++++---- .../src/l7/middleware.rs | 353 ++++++++++++++++++ .../src/l7/mod.rs | 1 + .../src/l7/relay.rs | 353 +----------------- .../openshell-supervisor-network/src/opa.rs | 117 +----- .../openshell-supervisor-network/src/proxy.rs | 6 +- 6 files changed, 467 insertions(+), 539 deletions(-) create mode 100644 crates/openshell-supervisor-network/src/l7/middleware.rs diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index a0fa85937..b8c4eeb9a 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -128,7 +128,7 @@ pub async fn run_sandbox( // Load policy and initialize OPA engine let openshell_endpoint_for_proxy = openshell_endpoint.clone(); let sandbox_name_for_agg = sandbox.clone(); - let (mut policy, opa_engine, retained_proto, middleware_install_pending) = load_policy( + let (mut policy, opa_engine, retained_proto, middleware_registry_status) = load_policy( sandbox_id.clone(), sandbox, openshell_endpoint.clone(), @@ -423,7 +423,7 @@ pub async fn run_sandbox( ocsf_enabled: poll_ocsf_enabled, provider_credentials: poll_provider_credentials, policy_local_ctx: poll_policy_local, - middleware_install_pending, + middleware_registry_status, }; tokio::spawn(async move { @@ -1371,11 +1371,7 @@ async fn load_policy( SandboxPolicy, Option>, Option, - // True when operator-registered middleware could not be connected at - // startup and the engine kept the built-in registry. The policy poll loop - // retries the install so a recovered service is picked up without a config - // change. - bool, + MiddlewareRegistryStatus, )> { // File mode: load OPA engine from rego rules + YAML data (dev override) if let (Some(policy_file), Some(data_file)) = (&policy_rules, &policy_data) { @@ -1406,7 +1402,12 @@ async fn load_policy( }; enrich_sandbox_baseline_paths(&mut policy); // File mode has no operator-registered middleware to connect. - return Ok((policy, Some(Arc::new(engine)), None, false)); + return Ok(( + policy, + Some(Arc::new(engine)), + None, + MiddlewareRegistryStatus::Synchronized, + )); } // gRPC mode: fetch typed proto policy, construct OPA engine from baked rules + proto data @@ -1490,14 +1491,11 @@ async fn load_policy( info!("Creating OPA engine from proto policy data"); let engine = OpaEngine::from_proto(&proto_policy)?; // Connect operator-registered middleware services. A connect/describe - // failure must not abort sandbox startup: unlike the previous hard - // failure, we degrade to the built-in registry and let each request's - // `on_error` policy govern matched traffic (deny for fail_closed, allow - // for fail_open). The policy poll loop retries the install so a - // recovered service is picked up without a config change. This mirrors - // the resilient live-reload path. + // failure keeps the built-in registry active so each request's + // `on_error` policy governs matched traffic. The policy poll loop + // retries the install without waiting for a config change. let middleware_services = sandbox_config.supervisor_middleware_services.clone(); - let middleware_install_pending = match grpc_retry("Middleware connect", || { + let middleware_registry_status = match grpc_retry("Middleware connect", || { openshell_supervisor_middleware::MiddlewareRegistry::connect_services( middleware_services.clone(), ) @@ -1505,7 +1503,7 @@ async fn load_policy( .await .and_then(|registry| engine.replace_middleware_registry(registry)) { - Ok(()) => false, + Ok(()) => MiddlewareRegistryStatus::Synchronized, Err(error) => { ocsf_emit!( ConfigStateChangeBuilder::new(ocsf_ctx()) @@ -1521,7 +1519,7 @@ async fn load_policy( )) .build() ); - true + MiddlewareRegistryStatus::NeedsReconciliation } }; let opa_engine = Some(Arc::new(engine)); @@ -1531,7 +1529,7 @@ async fn load_policy( policy, opa_engine, Some(proto_policy), - middleware_install_pending, + middleware_registry_status, )); } @@ -1651,6 +1649,12 @@ fn discover_policy_from_path(path: &std::path::Path) -> openshell_core::proto::S } } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum MiddlewareRegistryStatus { + Synchronized, + NeedsReconciliation, +} + /// Background loop that polls the server for policy updates. /// /// When a new version is detected, attempts to reload the OPA engine via @@ -1668,10 +1672,65 @@ struct PolicyPollLoopContext { ocsf_enabled: Arc, provider_credentials: ProviderCredentialState, policy_local_ctx: Option>, - /// True when `load_policy` degraded to the built-in middleware registry - /// because operator services could not be connected at startup. The poll - /// loop retries the install until it succeeds. - middleware_install_pending: bool, + middleware_registry_status: MiddlewareRegistryStatus, +} + +async fn reconcile_middleware_registry( + opa_engine: &OpaEngine, + desired_services: &[openshell_core::proto::SupervisorMiddlewareService], + current_services: &mut Vec, + status: &mut MiddlewareRegistryStatus, +) { + if *status == MiddlewareRegistryStatus::Synchronized + && desired_services == current_services.as_slice() + { + return; + } + + match openshell_supervisor_middleware::MiddlewareRegistry::connect_services( + desired_services.to_vec(), + ) + .await + .and_then(|registry| opa_engine.replace_middleware_registry(registry)) + { + Ok(()) => { + current_services.clear(); + current_services.extend_from_slice(desired_services); + *status = MiddlewareRegistryStatus::Synchronized; + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped( + "supervisor_middleware_service_count", + serde_json::json!(current_services.len()) + ) + .message(format!( + "Supervisor middleware registry reloaded [service_count:{}]", + current_services.len() + )) + .build() + ); + } + Err(error) => { + // Emit only on the transition into the failed state to avoid + // repeating the same finding on every poll during an outage. + if *status == MiddlewareRegistryStatus::Synchronized { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "failed") + .message(format!( + "Supervisor middleware registry reload failed, keeping last-known-good registry [error:{error}]" + )) + .build() + ); + } + *status = MiddlewareRegistryStatus::NeedsReconciliation; + } + } } async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { @@ -1684,10 +1743,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { let mut current_provider_env_revision: u64 = ctx.provider_credentials.snapshot().revision; let mut current_policy_hash = String::new(); let mut current_middleware_services = Vec::new(); - // Set when a middleware install is outstanding (degraded at startup or a - // failed reload). Drives a retry on every poll, independent of the config - // revision, so a recovered operator service is picked up promptly. - let mut middleware_sync_pending = ctx.middleware_install_pending; + let mut middleware_registry_status = ctx.middleware_registry_status; let mut current_settings: std::collections::HashMap< String, openshell_core::proto::EffectiveSetting, @@ -1723,63 +1779,17 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } }; - // Reconcile the supervisor middleware registry before evaluating the - // rest of the config. This runs independently of the config revision so - // an install that degraded at startup (or failed on an earlier poll) is - // retried here, letting a recovered operator service be picked up - // without waiting for a policy change. A failure keeps the - // last-known-good registry; the request path stays governed by each - // middleware's `on_error` policy, and a config change is still applied - // below rather than being blocked by the middleware outage. - if middleware_sync_pending - || result.supervisor_middleware_services != current_middleware_services - { - match openshell_supervisor_middleware::MiddlewareRegistry::connect_services( - result.supervisor_middleware_services.clone(), - ) - .await - .and_then(|registry| ctx.opa_engine.replace_middleware_registry(registry)) - { - Ok(()) => { - current_middleware_services = result.supervisor_middleware_services.clone(); - middleware_sync_pending = false; - ocsf_emit!( - ConfigStateChangeBuilder::new(ocsf_ctx()) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .state(StateId::Enabled, "loaded") - .unmapped( - "supervisor_middleware_service_count", - serde_json::json!(current_middleware_services.len()) - ) - .message(format!( - "Supervisor middleware registry reloaded [service_count:{}]", - current_middleware_services.len() - )) - .build() - ); - } - Err(error) => { - // Emit only on the transition into the failed state to avoid - // repeating the same finding on every poll during a - // sustained outage. The startup degrade path emits its own - // event. - if !middleware_sync_pending { - ocsf_emit!( - ConfigStateChangeBuilder::new(ocsf_ctx()) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .state(StateId::Other, "failed") - .message(format!( - "Supervisor middleware registry reload failed, keeping last-known-good registry [error:{error}]" - )) - .build() - ); - } - middleware_sync_pending = true; - } - } - } + // Reconcile independently of the config revision so a recovered + // operator service is picked up without waiting for a policy change. + // Failure preserves the last-known-good registry and does not block + // the remaining config updates below. + reconcile_middleware_registry( + &ctx.opa_engine, + &result.supervisor_middleware_services, + &mut current_middleware_services, + &mut middleware_registry_status, + ) + .await; let provider_env_changed = result.provider_env_revision != current_provider_env_revision; if result.config_revision == current_config_revision && !provider_env_changed { diff --git a/crates/openshell-supervisor-network/src/l7/middleware.rs b/crates/openshell-supervisor-network/src/l7/middleware.rs new file mode 100644 index 000000000..d8a8b6c30 --- /dev/null +++ b/crates/openshell-supervisor-network/src/l7/middleware.rs @@ -0,0 +1,353 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Supervisor middleware application for L7 requests. + +use crate::l7::relay::L7EvalContext; +use crate::opa::PolicyGenerationGuard; +use miette::{Result, miette}; +use openshell_ocsf::{ + ActionId, ActivityId, DetectionFindingBuilder, DispositionId, Endpoint, FindingInfo, + HttpActivityBuilder, HttpRequest, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, +}; +use std::collections::BTreeMap; +use std::path::PathBuf; +use tokio::io::{AsyncRead, AsyncWrite}; + +pub(crate) enum MiddlewareApplyResult { + Allowed(crate::l7::provider::L7Request), + Denied(String), +} + +/// Smallest body-buffering limit across the entries that actually resolved to a +/// registered binding. Unresolved entries (`is_resolved() == false`) report a +/// zero limit and are excluded here: they are handled by their `on_error` policy +/// in `evaluate_described` without inspecting the body, so letting a zero drag +/// the chain limit to zero would spuriously fail the whole chain over capacity. +/// Returns `None` when no entry resolved, so the caller can skip buffering. +pub(super) fn middleware_chain_body_limit( + chain: &[openshell_supervisor_middleware::DescribedChainEntry], +) -> Option { + chain + .iter() + .filter(|entry| entry.is_resolved()) + .map(openshell_supervisor_middleware::DescribedChainEntry::max_body_bytes) + .min() +} + +pub(crate) async fn apply_middleware_chain( + req: crate::l7::provider::L7Request, + client: &mut C, + ctx: &L7EvalContext, + chain: Vec, + runner: &openshell_supervisor_middleware::ChainRunner, + generation_guard: &PolicyGenerationGuard, +) -> Result { + apply_middleware_chain_for_scheme(req, client, ctx, "https", chain, runner, generation_guard) + .await +} + +pub(crate) async fn apply_middleware_chain_for_scheme( + req: crate::l7::provider::L7Request, + client: &mut C, + ctx: &L7EvalContext, + scheme: &str, + chain: Vec, + runner: &openshell_supervisor_middleware::ChainRunner, + generation_guard: &PolicyGenerationGuard, +) -> Result { + if chain.is_empty() { + return Ok(MiddlewareApplyResult::Allowed(req)); + } + let chain = runner.describe_chain(&chain).await?; + let Some(max_body_bytes) = middleware_chain_body_limit(&chain) else { + // No entry resolved to a registered binding, so nothing inspects the + // body. Apply each entry's `on_error` policy without buffering (an + // unresolved binding is handled before the body is read) and forward + // the original request unchanged if the chain allows. + let input = middleware_request_input( + scheme, + &req, + ctx, + BTreeMap::new(), + String::new(), + Vec::new(), + ); + let outcome = runner.evaluate_described(&chain, input).await?; + emit_middleware_events(ctx, &req, &outcome); + return Ok(if outcome.allowed { + MiddlewareApplyResult::Allowed(req) + } else { + MiddlewareApplyResult::Denied(outcome.reason) + }); + }; + let buffered = match crate::l7::rest::buffer_request_body_for_middleware( + &req, + client, + Some(generation_guard), + max_body_bytes, + ) + .await? + { + crate::l7::rest::BufferResult::Buffered(buffered) => buffered, + crate::l7::rest::BufferResult::OverCapacity { recoverable } => { + return Ok(resolve_unbuffered_body(ctx, req, &chain, recoverable)); + } + }; + let headers = safe_middleware_headers(&buffered.headers)?; + let query = raw_query_from_request_headers(&buffered.headers)?; + let input = middleware_request_input(scheme, &req, ctx, headers, query, buffered.body); + let outcome = runner.evaluate_described(&chain, input).await?; + emit_middleware_events(ctx, &req, &outcome); + let rebuilt = crate::l7::rest::rebuild_request_with_buffered_body( + &req, + &buffered.headers, + &outcome.body, + &outcome.added_headers, + )?; + if outcome.allowed { + Ok(MiddlewareApplyResult::Allowed(rebuilt)) + } else { + Ok(MiddlewareApplyResult::Denied(outcome.reason)) + } +} + +pub(super) fn middleware_request_input( + scheme: &str, + req: &crate::l7::provider::L7Request, + ctx: &L7EvalContext, + headers: BTreeMap, + query: String, + body: Vec, +) -> openshell_supervisor_middleware::HttpRequestInput { + openshell_supervisor_middleware::HttpRequestInput { + request_id: uuid::Uuid::new_v4().to_string(), + sandbox_id: openshell_ocsf::ctx::ctx().sandbox_id.clone(), + scheme: scheme.into(), + host: ctx.host.clone(), + port: ctx.port, + method: req.action.clone(), + path: req.target.clone(), + query, + headers, + body, + } +} + +pub(super) fn raw_query_from_request_headers(headers: &[u8]) -> Result { + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let target = header_str + .lines() + .next() + .and_then(|line| line.split_whitespace().nth(1)) + .ok_or_else(|| miette!("HTTP request line is missing a target"))?; + Ok(target + .split_once('?') + .map_or_else(String::new, |(_, query)| query.to_string())) +} + +/// Apply the chain's `on_error` policy when the request body cannot be buffered +/// for inspection because it exceeds the size cap. The RFC treats an unbufferable +/// body as an `on_error` event: it is denied unless every attached middleware is +/// `fail_open`, and passing it through is only safe when no bytes were consumed. +pub(super) fn resolve_unbuffered_body( + ctx: &L7EvalContext, + req: crate::l7::provider::L7Request, + chain: &[openshell_supervisor_middleware::DescribedChainEntry], + recoverable: bool, +) -> MiddlewareApplyResult { + let all_fail_open = chain + .iter() + .all(|entry| entry.on_error() == openshell_supervisor_middleware::OnError::FailOpen); + if recoverable && all_fail_open { + emit_middleware_body_unavailable(ctx, false); + return MiddlewareApplyResult::Allowed(req); + } + emit_middleware_body_unavailable(ctx, true); + MiddlewareApplyResult::Denied("middleware_failed: request_body_over_capacity".into()) +} + +fn emit_middleware_body_unavailable(ctx: &L7EvalContext, denied: bool) { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(if denied { + SeverityId::High + } else { + SeverityId::Medium + }) + .finding_info(FindingInfo::new( + "openshell.middleware.body_unavailable", + "Supervisor middleware could not inspect request body", + )) + .evidence_pairs(&[ + ("policy", ctx.policy_name.as_str()), + ("host", ctx.host.as_str()), + ("disposition", if denied { "denied" } else { "fail_open" }), + ]) + .message(if denied { + "Request body exceeded middleware inspection cap; denied" + } else { + "Request body exceeded middleware inspection cap; passed through (fail_open)" + }) + .build(); + ocsf_emit!(event); +} + +fn safe_middleware_headers(headers: &[u8]) -> Result> { + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut out = BTreeMap::new(); + for line in header_str.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + if name.is_empty() + || matches!( + name.as_str(), + "authorization" | "cookie" | "host" | "content-length" | "transfer-encoding" + ) + || name.starts_with("x-amz-") + || name.starts_with("x-openshell-credential") + { + continue; + } + out.insert(name, value.trim().to_string()); + } + Ok(out) +} + +pub(super) fn middleware_network_input(ctx: &L7EvalContext) -> crate::opa::NetworkInput { + crate::opa::NetworkInput { + host: ctx.host.clone(), + port: ctx.port, + binary_path: PathBuf::from(&ctx.binary_path), + binary_sha256: String::new(), + ancestors: ctx.ancestors.iter().map(PathBuf::from).collect(), + cmdline_paths: ctx.cmdline_paths.iter().map(PathBuf::from).collect(), + } +} + +/// Build the OCSF events describing a middleware chain outcome, in emission +/// order. Separated from `emit_middleware_events` so tests can assert on the +/// events deterministically without routing through the global tracing pipeline, +/// whose callsite-interest cache is process-global and races under parallel +/// tests. +pub(super) fn middleware_events( + ctx: &L7EvalContext, + req: &crate::l7::provider::L7Request, + outcome: &openshell_supervisor_middleware::ChainOutcome, +) -> Vec { + let mut events = Vec::new(); + for invocation in &outcome.applied { + let allowed = invocation.decision == openshell_core::proto::Decision::Allow; + let mut event = HttpActivityBuilder::new(openshell_ocsf::ctx::ctx()) + .activity(ActivityId::Other) + .action(if allowed { + ActionId::Allowed + } else { + ActionId::Denied + }) + .disposition(if allowed { + DispositionId::Allowed + } else { + DispositionId::Blocked + }) + .severity(if allowed { + SeverityId::Informational + } else { + SeverityId::Medium + }) + .http_request(HttpRequest::new( + &req.action, + OcsfUrl::new("http", &ctx.host, &req.target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(&ctx.policy_name, "middleware") + .unmapped("transformed", invocation.transformed) + .unmapped("failed", invocation.failed) + .message(format!( + "MIDDLEWARE {} {} decision={:?}", + invocation.name, invocation.implementation, invocation.decision + )); + if !allowed && !outcome.reason.is_empty() { + event = event + .status(StatusId::Failure) + .status_detail(&outcome.reason); + } + let event = event.build(); + events.push(event); + + // A middleware that failed but was bypassed under `fail_open` is an + // enforcement failure operators must be able to alert on, even though the + // request proceeded. + if invocation.failed && allowed { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(SeverityId::Medium) + .finding_info(FindingInfo::new( + "openshell.middleware.failure", + "Supervisor middleware failed open", + )) + .evidence_pairs(&[ + ("middleware", invocation.name.as_str()), + ("implementation", invocation.implementation.as_str()), + ]) + .unmapped("middleware", invocation.name.as_str()) + .unmapped("implementation", invocation.implementation.as_str()) + .message(format!( + "Middleware {} failed and was bypassed (fail_open)", + invocation.name + )) + .build(); + events.push(event); + } + } + if !outcome.allowed && outcome.reason.starts_with("middleware_failed:") { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(SeverityId::High) + .finding_info(FindingInfo::new( + "openshell.middleware.failure", + "Supervisor middleware failure", + )) + .message("Required supervisor middleware failed closed") + .build(); + events.push(event); + } + for finding in &outcome.findings { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(match finding.finding.severity.as_str() { + "high" => SeverityId::High, + "low" => SeverityId::Low, + _ => SeverityId::Medium, + }) + .finding_info(FindingInfo::new( + &finding.finding.r#type, + &finding.finding.label, + )) + .evidence_pairs(&[ + ("middleware", &finding.middleware), + ("count", &finding.finding.count.to_string()), + ]) + .unmapped("middleware", finding.middleware.as_str()) + .unmapped("count", finding.finding.count) + .message(format!( + "Middleware finding {} count={}", + finding.finding.r#type, finding.finding.count + )) + .build(); + events.push(event); + } + events +} + +/// Emit the OCSF events describing a middleware chain outcome through the +/// tracing pipeline. +fn emit_middleware_events( + ctx: &L7EvalContext, + req: &crate::l7::provider::L7Request, + outcome: &openshell_supervisor_middleware::ChainOutcome, +) { + for event in middleware_events(ctx, req, outcome) { + ocsf_emit!(event); + } +} diff --git a/crates/openshell-supervisor-network/src/l7/mod.rs b/crates/openshell-supervisor-network/src/l7/mod.rs index d10c19b8e..691aeb87d 100644 --- a/crates/openshell-supervisor-network/src/l7/mod.rs +++ b/crates/openshell-supervisor-network/src/l7/mod.rs @@ -12,6 +12,7 @@ pub mod graphql; pub(crate) mod http; pub mod inference; pub mod jsonrpc; +pub(crate) mod middleware; pub mod path; pub mod provider; pub mod relay; diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 329d4b42d..87ab2f39f 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -7,6 +7,14 @@ //! Parses each request within the tunnel, evaluates it against OPA policy, //! and either forwards or denies the request. +use crate::l7::middleware::{ + MiddlewareApplyResult, apply_middleware_chain, middleware_network_input, +}; +#[cfg(test)] +use crate::l7::middleware::{ + middleware_chain_body_limit, middleware_events, middleware_request_input, + raw_query_from_request_headers, resolve_unbuffered_body, +}; use crate::l7::provider::{L7Provider, RelayOutcome}; use crate::l7::rest::WebSocketExtensionMode; use crate::l7::{EnforcementMode, L7EndpointConfig, L7Protocol, L7RequestInfo}; @@ -15,12 +23,11 @@ use miette::{IntoDiagnostic, Result, miette}; use openshell_core::activity::{ActivitySender, try_record_activity}; use openshell_core::secrets::{self, SecretResolver}; use openshell_ocsf::{ - ActionId, ActivityId, DetectionFindingBuilder, DispositionId, Endpoint, FindingInfo, - HttpActivityBuilder, HttpRequest, NetworkActivityBuilder, SeverityId, StatusId, Url as OcsfUrl, - ocsf_emit, + ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, + NetworkActivityBuilder, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, }; +#[cfg(test)] use std::collections::BTreeMap; -use std::path::PathBuf; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn}; @@ -773,344 +780,6 @@ fn jsonrpc_engine_type(protocol: L7Protocol) -> &'static str { } } -pub(crate) enum MiddlewareApplyResult { - Allowed(crate::l7::provider::L7Request), - Denied(String), -} - -/// Smallest body-buffering limit across the entries that actually resolved to a -/// registered binding. Unresolved entries (`is_resolved() == false`) report a -/// zero limit and are excluded here: they are handled by their `on_error` policy -/// in `evaluate_described` without inspecting the body, so letting a zero drag -/// the chain limit to zero would spuriously fail the whole chain over capacity. -/// Returns `None` when no entry resolved, so the caller can skip buffering. -fn middleware_chain_body_limit( - chain: &[openshell_supervisor_middleware::DescribedChainEntry], -) -> Option { - chain - .iter() - .filter(|entry| entry.is_resolved()) - .map(openshell_supervisor_middleware::DescribedChainEntry::max_body_bytes) - .min() -} - -pub(crate) async fn apply_middleware_chain( - req: crate::l7::provider::L7Request, - client: &mut C, - ctx: &L7EvalContext, - chain: Vec, - runner: &openshell_supervisor_middleware::ChainRunner, - generation_guard: &PolicyGenerationGuard, -) -> Result { - apply_middleware_chain_for_scheme(req, client, ctx, "https", chain, runner, generation_guard) - .await -} - -pub(crate) async fn apply_middleware_chain_for_scheme( - req: crate::l7::provider::L7Request, - client: &mut C, - ctx: &L7EvalContext, - scheme: &str, - chain: Vec, - runner: &openshell_supervisor_middleware::ChainRunner, - generation_guard: &PolicyGenerationGuard, -) -> Result { - if chain.is_empty() { - return Ok(MiddlewareApplyResult::Allowed(req)); - } - let chain = runner.describe_chain(&chain).await?; - let Some(max_body_bytes) = middleware_chain_body_limit(&chain) else { - // No entry resolved to a registered binding, so nothing inspects the - // body. Apply each entry's `on_error` policy without buffering (an - // unresolved binding is handled before the body is read) and forward - // the original request unchanged if the chain allows. - let input = middleware_request_input( - scheme, - &req, - ctx, - BTreeMap::new(), - String::new(), - Vec::new(), - ); - let outcome = runner.evaluate_described(&chain, input).await?; - emit_middleware_events(ctx, &req, &outcome); - return Ok(if outcome.allowed { - MiddlewareApplyResult::Allowed(req) - } else { - MiddlewareApplyResult::Denied(outcome.reason) - }); - }; - let buffered = match crate::l7::rest::buffer_request_body_for_middleware( - &req, - client, - Some(generation_guard), - max_body_bytes, - ) - .await? - { - crate::l7::rest::BufferResult::Buffered(buffered) => buffered, - crate::l7::rest::BufferResult::OverCapacity { recoverable } => { - return Ok(resolve_unbuffered_body(ctx, req, &chain, recoverable)); - } - }; - let headers = safe_middleware_headers(&buffered.headers)?; - let query = raw_query_from_request_headers(&buffered.headers)?; - let input = middleware_request_input(scheme, &req, ctx, headers, query, buffered.body); - let outcome = runner.evaluate_described(&chain, input).await?; - emit_middleware_events(ctx, &req, &outcome); - let rebuilt = crate::l7::rest::rebuild_request_with_buffered_body( - &req, - &buffered.headers, - &outcome.body, - &outcome.added_headers, - )?; - if outcome.allowed { - Ok(MiddlewareApplyResult::Allowed(rebuilt)) - } else { - Ok(MiddlewareApplyResult::Denied(outcome.reason)) - } -} - -fn middleware_request_input( - scheme: &str, - req: &crate::l7::provider::L7Request, - ctx: &L7EvalContext, - headers: BTreeMap, - query: String, - body: Vec, -) -> openshell_supervisor_middleware::HttpRequestInput { - openshell_supervisor_middleware::HttpRequestInput { - request_id: uuid::Uuid::new_v4().to_string(), - sandbox_id: openshell_ocsf::ctx::ctx().sandbox_id.clone(), - scheme: scheme.into(), - host: ctx.host.clone(), - port: ctx.port, - method: req.action.clone(), - path: req.target.clone(), - query, - headers, - body, - } -} - -fn raw_query_from_request_headers(headers: &[u8]) -> Result { - let header_str = - std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; - let target = header_str - .lines() - .next() - .and_then(|line| line.split_whitespace().nth(1)) - .ok_or_else(|| miette!("HTTP request line is missing a target"))?; - Ok(target - .split_once('?') - .map_or_else(String::new, |(_, query)| query.to_string())) -} - -/// Apply the chain's `on_error` policy when the request body cannot be buffered -/// for inspection because it exceeds the size cap. The RFC treats an unbufferable -/// body as an `on_error` event: it is denied unless every attached middleware is -/// `fail_open`, and passing it through is only safe when no bytes were consumed. -fn resolve_unbuffered_body( - ctx: &L7EvalContext, - req: crate::l7::provider::L7Request, - chain: &[openshell_supervisor_middleware::DescribedChainEntry], - recoverable: bool, -) -> MiddlewareApplyResult { - let all_fail_open = chain - .iter() - .all(|entry| entry.on_error() == openshell_supervisor_middleware::OnError::FailOpen); - if recoverable && all_fail_open { - emit_middleware_body_unavailable(ctx, false); - return MiddlewareApplyResult::Allowed(req); - } - emit_middleware_body_unavailable(ctx, true); - MiddlewareApplyResult::Denied("middleware_failed: request_body_over_capacity".into()) -} - -fn emit_middleware_body_unavailable(ctx: &L7EvalContext, denied: bool) { - let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) - .severity(if denied { - SeverityId::High - } else { - SeverityId::Medium - }) - .finding_info(FindingInfo::new( - "openshell.middleware.body_unavailable", - "Supervisor middleware could not inspect request body", - )) - .evidence_pairs(&[ - ("policy", ctx.policy_name.as_str()), - ("host", ctx.host.as_str()), - ("disposition", if denied { "denied" } else { "fail_open" }), - ]) - .message(if denied { - "Request body exceeded middleware inspection cap; denied" - } else { - "Request body exceeded middleware inspection cap; passed through (fail_open)" - }) - .build(); - ocsf_emit!(event); -} - -fn safe_middleware_headers(headers: &[u8]) -> Result> { - let header_str = - std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; - let mut out = BTreeMap::new(); - for line in header_str.lines().skip(1) { - let Some((name, value)) = line.split_once(':') else { - continue; - }; - let name = name.trim().to_ascii_lowercase(); - if name.is_empty() - || matches!( - name.as_str(), - "authorization" | "cookie" | "host" | "content-length" | "transfer-encoding" - ) - || name.starts_with("x-amz-") - || name.starts_with("x-openshell-credential") - { - continue; - } - out.insert(name, value.trim().to_string()); - } - Ok(out) -} - -fn middleware_network_input(ctx: &L7EvalContext) -> crate::opa::NetworkInput { - crate::opa::NetworkInput { - host: ctx.host.clone(), - port: ctx.port, - binary_path: PathBuf::from(&ctx.binary_path), - binary_sha256: String::new(), - ancestors: ctx.ancestors.iter().map(PathBuf::from).collect(), - cmdline_paths: ctx.cmdline_paths.iter().map(PathBuf::from).collect(), - } -} - -/// Build the OCSF events describing a middleware chain outcome, in emission -/// order. Separated from `emit_middleware_events` so tests can assert on the -/// events deterministically without routing through the global tracing pipeline, -/// whose callsite-interest cache is process-global and races under parallel -/// tests. -fn middleware_events( - ctx: &L7EvalContext, - req: &crate::l7::provider::L7Request, - outcome: &openshell_supervisor_middleware::ChainOutcome, -) -> Vec { - let mut events = Vec::new(); - for invocation in &outcome.applied { - let allowed = invocation.decision == openshell_core::proto::Decision::Allow; - let mut event = HttpActivityBuilder::new(openshell_ocsf::ctx::ctx()) - .activity(ActivityId::Other) - .action(if allowed { - ActionId::Allowed - } else { - ActionId::Denied - }) - .disposition(if allowed { - DispositionId::Allowed - } else { - DispositionId::Blocked - }) - .severity(if allowed { - SeverityId::Informational - } else { - SeverityId::Medium - }) - .http_request(HttpRequest::new( - &req.action, - OcsfUrl::new("http", &ctx.host, &req.target, ctx.port), - )) - .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) - .firewall_rule(&ctx.policy_name, "middleware") - .unmapped("transformed", invocation.transformed) - .unmapped("failed", invocation.failed) - .message(format!( - "MIDDLEWARE {} {} decision={:?}", - invocation.name, invocation.implementation, invocation.decision - )); - if !allowed && !outcome.reason.is_empty() { - event = event - .status(StatusId::Failure) - .status_detail(&outcome.reason); - } - let event = event.build(); - events.push(event); - - // A middleware that failed but was bypassed under `fail_open` is an - // enforcement failure operators must be able to alert on, even though the - // request proceeded. - if invocation.failed && allowed { - let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) - .severity(SeverityId::Medium) - .finding_info(FindingInfo::new( - "openshell.middleware.failure", - "Supervisor middleware failed open", - )) - .evidence_pairs(&[ - ("middleware", invocation.name.as_str()), - ("implementation", invocation.implementation.as_str()), - ]) - .unmapped("middleware", invocation.name.as_str()) - .unmapped("implementation", invocation.implementation.as_str()) - .message(format!( - "Middleware {} failed and was bypassed (fail_open)", - invocation.name - )) - .build(); - events.push(event); - } - } - if !outcome.allowed && outcome.reason.starts_with("middleware_failed:") { - let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) - .severity(SeverityId::High) - .finding_info(FindingInfo::new( - "openshell.middleware.failure", - "Supervisor middleware failure", - )) - .message("Required supervisor middleware failed closed") - .build(); - events.push(event); - } - for finding in &outcome.findings { - let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) - .severity(match finding.finding.severity.as_str() { - "high" => SeverityId::High, - "low" => SeverityId::Low, - _ => SeverityId::Medium, - }) - .finding_info(FindingInfo::new( - &finding.finding.r#type, - &finding.finding.label, - )) - .evidence_pairs(&[ - ("middleware", &finding.middleware), - ("count", &finding.finding.count.to_string()), - ]) - .unmapped("middleware", finding.middleware.as_str()) - .unmapped("count", finding.finding.count) - .message(format!( - "Middleware finding {} count={}", - finding.finding.r#type, finding.finding.count - )) - .build(); - events.push(event); - } - events -} - -/// Emit the OCSF events describing a middleware chain outcome through the -/// tracing pipeline. -fn emit_middleware_events( - ctx: &L7EvalContext, - req: &crate::l7::provider::L7Request, - outcome: &openshell_supervisor_middleware::ChainOutcome, -) { - for event in middleware_events(ctx, req, outcome) { - ocsf_emit!(event); - } -} - /// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. async fn relay_rest( config: &L7EndpointConfig, diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index f0aea9c38..8702c6e6e 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -227,6 +227,7 @@ impl OpaEngine { let mut data: serde_json::Value = serde_json::from_str(&data_json_str) .map_err(|e| miette::miette!("internal: failed to parse proto JSON: {e}"))?; + // Validate BEFORE expanding presets let (errors, warnings) = crate::l7::validate_l7_policies(&data); for w in &warnings { openshell_ocsf::ocsf_emit!( @@ -272,27 +273,7 @@ impl OpaEngine { /// `allow_network` rule, and returns a `PolicyDecision` with the result, /// deny reason, and matched policy name. pub fn evaluate_network(&self, input: &NetworkInput) -> Result { - let ancestor_strs: Vec = input - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let cmdline_strs: Vec = input - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let input_json = serde_json::json!({ - "exec": { - "path": input.binary_path.to_string_lossy(), - "ancestors": ancestor_strs, - "cmdline_paths": cmdline_strs, - }, - "network": { - "host": input.host, - "port": input.port, - } - }); + let input_json = network_input_json(input); let mut engine = self .engine @@ -342,27 +323,7 @@ impl OpaEngine { &self, input: &NetworkInput, ) -> Result<(NetworkAction, u64)> { - let ancestor_strs: Vec = input - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let cmdline_strs: Vec = input - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let input_json = serde_json::json!({ - "exec": { - "path": input.binary_path.to_string_lossy(), - "ancestors": ancestor_strs, - "cmdline_paths": cmdline_strs, - }, - "network": { - "host": input.host, - "port": input.port, - } - }); + let input_json = network_input_json(input); let mut engine = self .engine @@ -552,27 +513,7 @@ impl OpaEngine { &self, input: &NetworkInput, ) -> Result<(Vec, u64)> { - let ancestor_strs: Vec = input - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let cmdline_strs: Vec = input - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let input_json = serde_json::json!({ - "exec": { - "path": input.binary_path.to_string_lossy(), - "ancestors": ancestor_strs, - "cmdline_paths": cmdline_strs, - }, - "network": { - "host": input.host, - "port": input.port, - } - }); + let input_json = network_input_json(input); let mut engine = self .engine @@ -630,27 +571,7 @@ impl OpaEngine { /// denial while preserving separate handling for `allowed_ips` and advisor /// proposals. pub fn query_exact_declared_endpoint_host(&self, input: &NetworkInput) -> Result { - let ancestor_strs: Vec = input - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let cmdline_strs: Vec = input - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let input_json = serde_json::json!({ - "exec": { - "path": input.binary_path.to_string_lossy(), - "ancestors": ancestor_strs, - "cmdline_paths": cmdline_strs, - }, - "network": { - "host": input.host, - "port": input.port, - } - }); + let input_json = network_input_json(input); let mut engine = self .engine @@ -1711,7 +1632,7 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St "middleware": mw.middleware, }); if let Some(config) = &mw.config { - value["config"] = prost_struct_to_json(config); + value["config"] = openshell_core::proto_struct::struct_to_json_value(config); } if !mw.on_error.is_empty() { value["on_error"] = mw.on_error.clone().into(); @@ -1740,32 +1661,6 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St .to_string() } -fn prost_struct_to_json(config: &prost_types::Struct) -> serde_json::Value { - serde_json::Value::Object( - config - .fields - .iter() - .map(|(key, value)| (key.clone(), prost_value_to_json(value))) - .collect(), - ) -} - -fn prost_value_to_json(value: &prost_types::Value) -> serde_json::Value { - match value.kind.as_ref() { - Some(prost_types::value::Kind::NullValue(_)) | None => serde_json::Value::Null, - Some(prost_types::value::Kind::BoolValue(value)) => serde_json::Value::Bool(*value), - Some(prost_types::value::Kind::NumberValue(value)) => serde_json::Number::from_f64(*value) - .map_or(serde_json::Value::Null, serde_json::Value::Number), - Some(prost_types::value::Kind::StringValue(value)) => { - serde_json::Value::String(value.clone()) - } - Some(prost_types::value::Kind::ListValue(value)) => { - serde_json::Value::Array(value.values.iter().map(prost_value_to_json).collect()) - } - Some(prost_types::value::Kind::StructValue(value)) => prost_struct_to_json(value), - } -} - #[cfg(test)] #[allow( clippy::needless_raw_string_hashes, diff --git a/crates/openshell-supervisor-network/src/proxy.rs b/crates/openshell-supervisor-network/src/proxy.rs index 8616c3b2c..8a10ea3fd 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -4225,7 +4225,7 @@ async fn handle_forward_proxy( &upstream_target, forward_request_bytes, )?; - forward_request_bytes = match crate::l7::relay::apply_middleware_chain_for_scheme( + forward_request_bytes = match crate::l7::middleware::apply_middleware_chain_for_scheme( request, client, &l7_ctx, @@ -4236,8 +4236,8 @@ async fn handle_forward_proxy( ) .await? { - crate::l7::relay::MiddlewareApplyResult::Allowed(request) => request.raw_header, - crate::l7::relay::MiddlewareApplyResult::Denied(reason) => { + crate::l7::middleware::MiddlewareApplyResult::Allowed(request) => request.raw_header, + crate::l7::middleware::MiddlewareApplyResult::Denied(reason) => { emit_activity_simple(activity_tx, true, "middleware"); respond( client, From d2a9791308f969117363aad5654953dfb7f668e6 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 3 Jul 2026 12:23:27 -0700 Subject: [PATCH 24/27] refactor(middleware): move shared contracts to core Signed-off-by: Piotr Mlocek --- Cargo.lock | 1 - architecture/sandbox.md | 4 +- crates/openshell-core/README.md | 7 ++ crates/openshell-core/src/lib.rs | 1 + crates/openshell-core/src/middleware.rs | 77 +++++++++++++++++++ crates/openshell-policy/Cargo.toml | 1 - crates/openshell-policy/src/lib.rs | 11 ++- .../src/builtins/mod.rs | 9 --- .../src/builtins/secrets.rs | 18 +---- .../src/lib.rs | 11 +-- .../src/l7/middleware.rs | 6 +- 11 files changed, 99 insertions(+), 47 deletions(-) create mode 100644 crates/openshell-core/src/middleware.rs diff --git a/Cargo.lock b/Cargo.lock index 6819aab32..3552ea04c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3772,7 +3772,6 @@ dependencies = [ "glob", "miette", "openshell-core", - "openshell-supervisor-middleware", "prost-types", "serde", "serde_json", diff --git a/architecture/sandbox.md b/architecture/sandbox.md index d63c4fbaa..bf09b3116 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -72,7 +72,9 @@ of the network rule that admitted the request. Built-ins run in-process; operator-registered services are called directly from the supervisor over the common middleware gRPC contract. The gateway validates external service capabilities and policy-owned config before delivery. Supervisors keep -the last-known-good service registry when a live config reload fails. +the last-known-good service registry when a live config reload fails. Built-in +middleware identifiers and pure config validation live in `openshell-core` so +policy admission does not depend on the supervisor runtime implementation. `https://inference.local` is special. It bypasses OPA network policy and is handled by the inference interception path: diff --git a/crates/openshell-core/README.md b/crates/openshell-core/README.md index 51da847b8..b40ae121c 100644 --- a/crates/openshell-core/README.md +++ b/crates/openshell-core/README.md @@ -50,3 +50,10 @@ router agree on provider defaults. Profiles define: Do not duplicate provider-specific inference behavior in callers. Add shared behavior here, then consume it from the gateway, sandbox, and router. + +## Middleware Contracts + +Built-in supervisor middleware identifiers and pure configuration validation +live in `openshell_core::middleware`. Policy admission and the supervisor +runtime consume the same contract without introducing a dependency from the +policy crate to the supervisor implementation. diff --git a/crates/openshell-core/src/lib.rs b/crates/openshell-core/src/lib.rs index 321296369..9ba041dd8 100644 --- a/crates/openshell-core/src/lib.rs +++ b/crates/openshell-core/src/lib.rs @@ -23,6 +23,7 @@ pub mod grpc_client; pub mod image; pub mod inference; pub mod metadata; +pub mod middleware; pub mod net; pub mod paths; pub mod policy; diff --git a/crates/openshell-core/src/middleware.rs b/crates/openshell-core/src/middleware.rs new file mode 100644 index 000000000..6fb15b1c1 --- /dev/null +++ b/crates/openshell-core/src/middleware.rs @@ -0,0 +1,77 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared supervisor middleware identifiers and policy validation contracts. + +use miette::{Result, miette}; + +/// Binding identifier for the built-in secret redaction middleware. +pub const BUILTIN_SECRETS: &str = "openshell/secrets"; + +/// Validate policy-owned configuration for a built-in middleware. +pub fn validate_builtin_config(implementation: &str, config: &prost_types::Struct) -> Result<()> { + match implementation { + BUILTIN_SECRETS => validate_secrets_config(config), + other => Err(miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + } +} + +fn validate_secrets_config(config: &prost_types::Struct) -> Result<()> { + let mode = config + .fields + .get("secrets") + .and_then(|value| match value.kind.as_ref() { + Some(prost_types::value::Kind::StringValue(value)) => Some(value.as_str()), + _ => None, + }) + .unwrap_or("redact"); + if mode != "redact" { + return Err(miette!( + "{BUILTIN_SECRETS} only supports config.secrets: redact in phase 1" + )); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn secrets_config_defaults_to_redact() { + validate_builtin_config(BUILTIN_SECRETS, &prost_types::Struct::default()).unwrap(); + } + + #[test] + fn secrets_config_rejects_unsupported_mode() { + let config = prost_types::Struct { + fields: std::iter::once(( + "secrets".to_string(), + prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue("allow".into())), + }, + )) + .collect(), + }; + + let error = validate_builtin_config(BUILTIN_SECRETS, &config).unwrap_err(); + assert!( + error + .to_string() + .contains("only supports config.secrets: redact") + ); + } + + #[test] + fn rejects_unknown_builtin() { + let error = validate_builtin_config("openshell/unknown", &prost_types::Struct::default()) + .unwrap_err(); + assert!( + error + .to_string() + .contains("implementation 'openshell/unknown' is not available") + ); + } +} diff --git a/crates/openshell-policy/Cargo.toml b/crates/openshell-policy/Cargo.toml index 073728db1..036964b72 100644 --- a/crates/openshell-policy/Cargo.toml +++ b/crates/openshell-policy/Cargo.toml @@ -13,7 +13,6 @@ repository.workspace = true [dependencies] glob = { workspace = true } openshell-core = { path = "../openshell-core", default-features = false } -openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } prost-types = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index efcca2fca..f09006b2e 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -1532,7 +1532,7 @@ pub fn validate_sandbox_policy( reason: "implementation must not be empty".to_string(), }); } else if middleware.middleware.starts_with("openshell/") - && middleware.middleware != openshell_supervisor_middleware::BUILTIN_SECRETS + && middleware.middleware != openshell_core::middleware::BUILTIN_SECRETS { violations.push(PolicyViolation::InvalidMiddlewareConfig { name: middleware.name.clone(), @@ -1572,12 +1572,11 @@ pub fn validate_sandbox_policy( } } - if middleware.middleware == openshell_supervisor_middleware::BUILTIN_SECRETS { + if middleware.middleware == openshell_core::middleware::BUILTIN_SECRETS { let config = middleware.config.clone().unwrap_or_default(); - if let Err(error) = openshell_supervisor_middleware::validate_builtin_config( - &middleware.middleware, - &config, - ) { + if let Err(error) = + openshell_core::middleware::validate_builtin_config(&middleware.middleware, &config) + { violations.push(PolicyViolation::InvalidBuiltinMiddlewareConfig { name: middleware.name.clone(), reason: error.to_string(), diff --git a/crates/openshell-supervisor-middleware/src/builtins/mod.rs b/crates/openshell-supervisor-middleware/src/builtins/mod.rs index 1db620220..176f4786c 100644 --- a/crates/openshell-supervisor-middleware/src/builtins/mod.rs +++ b/crates/openshell-supervisor-middleware/src/builtins/mod.rs @@ -10,15 +10,6 @@ pub fn describe() -> Vec { vec![secrets::describe()] } -pub fn validate_config(binding_id: &str, config: &prost_types::Struct) -> Result<()> { - match binding_id { - secrets::BINDING_ID => secrets::validate_config(config), - other => Err(miette!( - "middleware implementation '{other}' is not available in phase 1" - )), - } -} - pub fn evaluate_http_request(evaluation: &HttpRequestEvaluation) -> Result { match evaluation.binding_id.as_str() { secrets::BINDING_ID => secrets::evaluate_http_request(evaluation), diff --git a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs index d88ac080d..716301c44 100644 --- a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs +++ b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs @@ -10,7 +10,7 @@ use openshell_core::proto::{ }; use regex::Regex; -pub const BINDING_ID: &str = "openshell/secrets"; +pub const BINDING_ID: &str = openshell_core::middleware::BUILTIN_SECRETS; const OPERATION: &str = "HttpRequest"; const PHASE: &str = "pre_credentials"; const MAX_BODY_BYTES: u64 = 256 * 1024; @@ -54,21 +54,7 @@ static SECRET_PATTERNS: LazyLock<[SecretPattern; 2]> = LazyLock::new(|| { }); pub fn validate_config(config: &prost_types::Struct) -> Result<()> { - let mode = config - .fields - .get("secrets") - .and_then(|value| match value.kind.as_ref() { - Some(prost_types::value::Kind::StringValue(value)) => Some(value.as_str()), - _ => None, - }) - .unwrap_or("redact"); - if mode != "redact" { - return Err(miette!( - "{} only supports config.secrets: redact in phase 1", - BINDING_ID - )); - } - Ok(()) + openshell_core::middleware::validate_builtin_config(BINDING_ID, config) } pub fn evaluate_http_request(evaluation: &HttpRequestEvaluation) -> Result { diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index d2f8642b1..9b067b31e 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -11,6 +11,7 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::{Arc, LazyLock}; use miette::{Result, miette}; +pub use openshell_core::middleware::{BUILTIN_SECRETS, validate_builtin_config}; pub use service::InProcessMiddlewareService; use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; @@ -23,18 +24,8 @@ use tokio::sync::OnceCell; use tonic::Request; pub const API_VERSION: &str = "openshell.middleware.v1"; -pub const BUILTIN_SECRETS: &str = builtins::secrets::BINDING_ID; const HTTP_REQUEST_OPERATION: &str = "HttpRequest"; const PRE_CREDENTIALS_PHASE: &str = "pre_credentials"; - -/// Validate the configuration for an in-process middleware implementation. -/// -/// Policy admission uses this same implementation-specific validation before a -/// configuration can reach the request path. -pub fn validate_builtin_config(implementation: &str, config: &prost_types::Struct) -> Result<()> { - builtins::validate_config(implementation, config) -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum OnError { FailClosed, diff --git a/crates/openshell-supervisor-network/src/l7/middleware.rs b/crates/openshell-supervisor-network/src/l7/middleware.rs index d8a8b6c30..a33d06e37 100644 --- a/crates/openshell-supervisor-network/src/l7/middleware.rs +++ b/crates/openshell-supervisor-network/src/l7/middleware.rs @@ -14,7 +14,7 @@ use std::collections::BTreeMap; use std::path::PathBuf; use tokio::io::{AsyncRead, AsyncWrite}; -pub(crate) enum MiddlewareApplyResult { +pub enum MiddlewareApplyResult { Allowed(crate::l7::provider::L7Request), Denied(String), } @@ -35,7 +35,7 @@ pub(super) fn middleware_chain_body_limit( .min() } -pub(crate) async fn apply_middleware_chain( +pub async fn apply_middleware_chain( req: crate::l7::provider::L7Request, client: &mut C, ctx: &L7EvalContext, @@ -47,7 +47,7 @@ pub(crate) async fn apply_middleware_chain( +pub async fn apply_middleware_chain_for_scheme( req: crate::l7::provider::L7Request, client: &mut C, ctx: &L7EvalContext, From 64439ea327b80bbc12bae145c4a8653450cfc5ec Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 3 Jul 2026 13:05:06 -0700 Subject: [PATCH 25/27] refactor(policy): extract middleware serialization Signed-off-by: Piotr Mlocek --- architecture/sandbox.md | 4 +- crates/openshell-cli/src/run.rs | 36 +-- crates/openshell-core/README.md | 7 + crates/openshell-core/src/proto_struct.rs | 67 ++++- crates/openshell-driver-docker/src/tests.rs | 39 +-- .../openshell-driver-kubernetes/src/driver.rs | 37 +-- .../openshell-driver-podman/src/container.rs | 39 +-- crates/openshell-driver-podman/src/driver.rs | 39 +-- crates/openshell-policy/src/lib.rs | 280 ++---------------- crates/openshell-policy/src/middleware.rs | 217 ++++++++++++++ 10 files changed, 337 insertions(+), 428 deletions(-) create mode 100644 crates/openshell-policy/src/middleware.rs diff --git a/architecture/sandbox.md b/architecture/sandbox.md index bf09b3116..a9f1208e0 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -74,7 +74,9 @@ over the common middleware gRPC contract. The gateway validates external service capabilities and policy-owned config before delivery. Supervisors keep the last-known-good service registry when a live config reload fails. Built-in middleware identifiers and pure config validation live in `openshell-core` so -policy admission does not depend on the supervisor runtime implementation. +policy admission does not depend on the supervisor runtime implementation. The +policy and runtime also share the core JSON/protobuf adapter for middleware +configuration, keeping serialization consistent across that boundary. `https://inference.local` is special. It bypasses OPA network policy and is handled by the inference interception path: diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index f56bb7151..4eaf429c4 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -1648,39 +1648,9 @@ fn parse_driver_config_json(value: &str) -> Result { )); }; - Ok(prost_types::Struct { - fields: fields - .into_iter() - .map(|(key, value)| json_to_protobuf_value(value).map(|value| (key, value))) - .collect::>()?, - }) -} - -fn json_to_protobuf_value(value: serde_json::Value) -> Result { - use prost_types::{ListValue, Struct, Value, value::Kind}; - - let kind = match value { - serde_json::Value::Null => Kind::NullValue(0), - serde_json::Value::Bool(value) => Kind::BoolValue(value), - serde_json::Value::Number(value) => Kind::NumberValue(value.as_f64().ok_or_else(|| { - miette!("--driver-config-json contains a number that cannot be represented") - })?), - serde_json::Value::String(value) => Kind::StringValue(value), - serde_json::Value::Array(values) => Kind::ListValue(ListValue { - values: values - .into_iter() - .map(json_to_protobuf_value) - .collect::>()?, - }), - serde_json::Value::Object(fields) => Kind::StructValue(Struct { - fields: fields - .into_iter() - .map(|(key, value)| json_to_protobuf_value(value).map(|value| (key, value))) - .collect::>()?, - }), - }; - - Ok(Value { kind: Some(kind) }) + openshell_core::proto_struct::json_object_to_struct(fields) + .into_diagnostic() + .wrap_err("--driver-config-json contains a value that cannot be represented") } fn validate_cpu_quantity(value: &str) -> Result { diff --git a/crates/openshell-core/README.md b/crates/openshell-core/README.md index b40ae121c..e27ab167f 100644 --- a/crates/openshell-core/README.md +++ b/crates/openshell-core/README.md @@ -57,3 +57,10 @@ Built-in supervisor middleware identifiers and pure configuration validation live in `openshell_core::middleware`. Policy admission and the supervisor runtime consume the same contract without introducing a dependency from the policy crate to the supervisor implementation. + +## Protobuf Struct Conversion + +Use `openshell_core::proto_struct` when crossing between `serde_json` values and +`prost_types::{Struct, Value}`. Both conversion directions live in this module; +JSON-to-protobuf conversion is fallible so callers cannot silently replace an +unrepresentable number. diff --git a/crates/openshell-core/src/proto_struct.rs b/crates/openshell-core/src/proto_struct.rs index 8871d6f6a..b90c52ca7 100644 --- a/crates/openshell-core/src/proto_struct.rs +++ b/crates/openshell-core/src/proto_struct.rs @@ -1,10 +1,57 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! Helpers for decoding `google.protobuf.Struct` values. +//! Helpers for converting `google.protobuf.Struct` values to and from JSON. use serde::{Deserialize, Deserializer, de::Error as _}; +/// Errors converting JSON values into protobuf well-known types. +#[derive(Debug, thiserror::Error)] +pub enum ProtoStructError { + /// A JSON number cannot be represented by protobuf's double value. + #[error("JSON number {0} cannot be represented as a protobuf double")] + UnrepresentableNumber(serde_json::Number), +} + +/// Convert a JSON object into a protobuf Struct. +pub fn json_object_to_struct( + object: serde_json::Map, +) -> Result { + Ok(prost_types::Struct { + fields: object + .into_iter() + .map(|(key, value)| json_value_to_proto(value).map(|value| (key, value))) + .collect::>()?, + }) +} + +/// Convert a JSON value into a protobuf Value. +pub fn json_value_to_proto( + value: serde_json::Value, +) -> Result { + use prost_types::{ListValue, Value, value::Kind}; + + let kind = match value { + serde_json::Value::Null => Kind::NullValue(0), + serde_json::Value::Bool(value) => Kind::BoolValue(value), + serde_json::Value::Number(value) => Kind::NumberValue( + value + .as_f64() + .ok_or_else(|| ProtoStructError::UnrepresentableNumber(value.clone()))?, + ), + serde_json::Value::String(value) => Kind::StringValue(value), + serde_json::Value::Array(values) => Kind::ListValue(ListValue { + values: values + .into_iter() + .map(json_value_to_proto) + .collect::>()?, + }), + serde_json::Value::Object(object) => Kind::StructValue(json_object_to_struct(object)?), + }; + + Ok(Value { kind: Some(kind) }) +} + /// Convert a protobuf Struct into a JSON object for typed serde decoding. #[must_use] pub fn struct_to_json_object( @@ -72,6 +119,24 @@ where mod tests { use super::*; + #[test] + fn json_and_proto_values_round_trip() { + let json = serde_json::json!({ + "null": null, + "bool": true, + "number": 42.5, + "string": "value", + "list": [1.0, {"nested": "value"}], + }); + let serde_json::Value::Object(object) = json.clone() else { + unreachable!(); + }; + + let proto = json_object_to_struct(object).unwrap(); + + assert_eq!(struct_to_json_value(&proto), json); + } + #[derive(Debug, Default, Deserialize)] #[serde(default)] struct TestConfig { diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index d42d41099..bc37d87be 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -118,40 +118,11 @@ fn runtime_config() -> DockerDriverRuntimeConfig { } fn json_struct(value: serde_json::Value) -> prost_types::Struct { - match json_value(value).kind { - Some(prost_types::value::Kind::StructValue(value)) => value, - _ => panic!("expected JSON object"), - } -} - -fn json_value(value: serde_json::Value) -> prost_types::Value { - match value { - serde_json::Value::Null => prost_types::Value { kind: None }, - serde_json::Value::Bool(value) => prost_types::Value { - kind: Some(prost_types::value::Kind::BoolValue(value)), - }, - serde_json::Value::Number(value) => prost_types::Value { - kind: value.as_f64().map(prost_types::value::Kind::NumberValue), - }, - serde_json::Value::String(value) => prost_types::Value { - kind: Some(prost_types::value::Kind::StringValue(value)), - }, - serde_json::Value::Array(values) => prost_types::Value { - kind: Some(prost_types::value::Kind::ListValue( - prost_types::ListValue { - values: values.into_iter().map(json_value).collect(), - }, - )), - }, - serde_json::Value::Object(values) => prost_types::Value { - kind: Some(prost_types::value::Kind::StructValue(prost_types::Struct { - fields: values - .into_iter() - .map(|(key, value)| (key, json_value(value))) - .collect(), - })), - }, - } + let serde_json::Value::Object(object) = value else { + panic!("expected JSON object"); + }; + openshell_core::proto_struct::json_object_to_struct(object) + .expect("test JSON must convert to a protobuf Struct") } fn inspected_volume(driver: &str, options: HashMap) -> bollard::models::Volume { diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 166f18b1c..eb34905c3 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -2200,38 +2200,11 @@ mod tests { std::sync::LazyLock::new(|| std::sync::Mutex::new(())); fn json_struct(value: serde_json::Value) -> Struct { - match json_value(value).kind { - Some(Kind::StructValue(value)) => value, - _ => panic!("expected JSON object"), - } - } - - fn json_value(value: serde_json::Value) -> Value { - match value { - serde_json::Value::Null => Value { kind: None }, - serde_json::Value::Bool(value) => Value { - kind: Some(Kind::BoolValue(value)), - }, - serde_json::Value::Number(value) => Value { - kind: value.as_f64().map(Kind::NumberValue), - }, - serde_json::Value::String(value) => Value { - kind: Some(Kind::StringValue(value)), - }, - serde_json::Value::Array(values) => Value { - kind: Some(Kind::ListValue(prost_types::ListValue { - values: values.into_iter().map(json_value).collect(), - })), - }, - serde_json::Value::Object(values) => Value { - kind: Some(Kind::StructValue(Struct { - fields: values - .into_iter() - .map(|(key, value)| (key, json_value(value))) - .collect(), - })), - }, - } + let serde_json::Value::Object(object) = value else { + panic!("expected JSON object"); + }; + openshell_core::proto_struct::json_object_to_struct(object) + .expect("test JSON must convert to a protobuf Struct") } fn kube_api_error(code: u16, message: &str) -> KubeError { diff --git a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index 7c9ab269f..73bbb5975 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -1128,40 +1128,11 @@ mod tests { std::sync::LazyLock::new(|| std::sync::Mutex::new(())); fn json_struct(value: Value) -> prost_types::Struct { - match json_value(value).kind { - Some(prost_types::value::Kind::StructValue(value)) => value, - _ => panic!("expected JSON object"), - } - } - - fn json_value(value: Value) -> prost_types::Value { - match value { - Value::Null => prost_types::Value { kind: None }, - Value::Bool(value) => prost_types::Value { - kind: Some(prost_types::value::Kind::BoolValue(value)), - }, - Value::Number(value) => prost_types::Value { - kind: value.as_f64().map(prost_types::value::Kind::NumberValue), - }, - Value::String(value) => prost_types::Value { - kind: Some(prost_types::value::Kind::StringValue(value)), - }, - Value::Array(values) => prost_types::Value { - kind: Some(prost_types::value::Kind::ListValue( - prost_types::ListValue { - values: values.into_iter().map(json_value).collect(), - }, - )), - }, - Value::Object(values) => prost_types::Value { - kind: Some(prost_types::value::Kind::StructValue(prost_types::Struct { - fields: values - .into_iter() - .map(|(key, value)| (key, json_value(value))) - .collect(), - })), - }, - } + let Value::Object(object) = value else { + panic!("expected JSON object"); + }; + proto_struct::json_object_to_struct(object) + .expect("test JSON must convert to a protobuf Struct") } fn gpu_resources(count: Option) -> ResourceRequirements { diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index 27d510281..a76f5485a 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -1274,41 +1274,12 @@ mod tests { PodmanComputeDriver::for_tests(config) } - fn json_value(value: serde_json::Value) -> prost_types::Value { - match value { - serde_json::Value::Null => prost_types::Value { kind: None }, - serde_json::Value::Bool(value) => prost_types::Value { - kind: Some(prost_types::value::Kind::BoolValue(value)), - }, - serde_json::Value::Number(value) => prost_types::Value { - kind: value.as_f64().map(prost_types::value::Kind::NumberValue), - }, - serde_json::Value::String(value) => prost_types::Value { - kind: Some(prost_types::value::Kind::StringValue(value)), - }, - serde_json::Value::Array(values) => prost_types::Value { - kind: Some(prost_types::value::Kind::ListValue( - prost_types::ListValue { - values: values.into_iter().map(json_value).collect(), - }, - )), - }, - serde_json::Value::Object(values) => prost_types::Value { - kind: Some(prost_types::value::Kind::StructValue(prost_types::Struct { - fields: values - .into_iter() - .map(|(key, value)| (key, json_value(value))) - .collect(), - })), - }, - } - } - fn json_struct(value: serde_json::Value) -> prost_types::Struct { - match json_value(value).kind { - Some(prost_types::value::Kind::StructValue(value)) => value, - _ => panic!("expected JSON object"), - } + let serde_json::Value::Object(object) = value else { + panic!("expected JSON object"); + }; + openshell_core::proto_struct::json_object_to_struct(object) + .expect("test JSON must convert to a protobuf Struct") } fn sandbox_with_volume_mount(volume: &str) -> DriverSandbox { diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index f09006b2e..d388e28a2 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -11,16 +11,17 @@ mod compose; mod merge; +mod middleware; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap}; use std::fmt; use std::path::Path; use miette::{IntoDiagnostic, Result, WrapErr}; use openshell_core::proto::{ FilesystemPolicy, GraphqlOperation, L7Allow, L7DenyRule, L7QueryMatcher, L7Rule, - LandlockPolicy, McpOptions, MiddlewareEndpointSelector, NetworkBinary, NetworkEndpoint, - NetworkMiddlewareConfig, NetworkPolicyRule, ProcessPolicy, SandboxPolicy, + LandlockPolicy, McpOptions, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, ProcessPolicy, + SandboxPolicy, }; use serde::{Deserialize, Serialize}; @@ -32,6 +33,7 @@ pub use merge::{ PolicyMergeError, PolicyMergeOp, PolicyMergeResult, PolicyMergeWarning, generated_rule_name, merge_policy, policy_covers_rule, }; +pub use middleware::middleware_host_matches; // --------------------------------------------------------------------------- // YAML serde types (canonical — used for both parsing and serialization) @@ -50,7 +52,7 @@ struct PolicyFile { #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] network_policies: BTreeMap, #[serde(default, skip_serializing_if = "Vec::is_empty")] - network_middlewares: Vec, + network_middlewares: Vec, } #[derive(Debug, Serialize, Deserialize)] @@ -91,28 +93,6 @@ struct NetworkPolicyRuleDef { binaries: Vec, } -#[derive(Debug, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -struct NetworkMiddlewareConfigDef { - name: String, - middleware: String, - #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] - config: BTreeMap, - #[serde(default, skip_serializing_if = "String::is_empty")] - on_error: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - endpoints: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -struct MiddlewareEndpointSelectorDef { - #[serde(default, skip_serializing_if = "Vec::is_empty")] - include: Vec, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - exclude: Vec, -} - #[derive(Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] struct NetworkEndpointDef { @@ -695,21 +675,10 @@ fn yaml_mcp_method( method.to_string() } -fn to_proto(raw: PolicyFile) -> SandboxPolicy { - let network_middlewares = raw - .network_middlewares - .into_iter() - .map(|mw| NetworkMiddlewareConfig { - name: mw.name, - middleware: mw.middleware, - config: Some(json_map_to_struct(mw.config)), - on_error: mw.on_error, - endpoints: mw.endpoints.map(|selector| MiddlewareEndpointSelector { - include: selector.include, - exclude: selector.exclude, - }), - }) - .collect(); +fn to_proto(raw: PolicyFile) -> Result { + let network_middlewares = middleware::into_proto(raw.network_middlewares) + .into_diagnostic() + .wrap_err("failed to convert network middleware config")?; let network_policies = raw .network_policies @@ -800,7 +769,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { }) .collect(); - SandboxPolicy { + Ok(SandboxPolicy { version: raw.version, filesystem: raw.filesystem_policy.map(|fs| FilesystemPolicy { include_workdir: fs.include_workdir, @@ -816,7 +785,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { }), network_policies, network_middlewares, - } + }) } // --------------------------------------------------------------------------- @@ -948,27 +917,7 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { }) .collect(); - let network_middlewares = policy - .network_middlewares - .iter() - .map(|mw| NetworkMiddlewareConfigDef { - name: mw.name.clone(), - middleware: mw.middleware.clone(), - config: mw - .config - .as_ref() - .map(struct_to_json_map) - .unwrap_or_default(), - on_error: mw.on_error.clone(), - endpoints: mw - .endpoints - .as_ref() - .map(|selector| MiddlewareEndpointSelectorDef { - include: selector.include.clone(), - exclude: selector.exclude.clone(), - }), - }) - .collect(); + let network_middlewares = middleware::from_proto(&policy.network_middlewares); PolicyFile { version: policy.version, @@ -980,68 +929,6 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { } } -fn json_map_to_struct(map: BTreeMap) -> prost_types::Struct { - prost_types::Struct { - fields: map - .into_iter() - .map(|(key, value)| (key, json_to_protobuf_value(value))) - .collect(), - } -} - -fn json_to_protobuf_value(value: serde_json::Value) -> prost_types::Value { - use prost_types::{ListValue, Struct, Value, value::Kind}; - Value { - kind: Some(match value { - serde_json::Value::Null => Kind::NullValue(0), - serde_json::Value::Bool(value) => Kind::BoolValue(value), - serde_json::Value::Number(value) => { - Kind::NumberValue(value.as_f64().unwrap_or_default()) - } - serde_json::Value::String(value) => Kind::StringValue(value), - serde_json::Value::Array(values) => Kind::ListValue(ListValue { - values: values.into_iter().map(json_to_protobuf_value).collect(), - }), - serde_json::Value::Object(values) => Kind::StructValue(Struct { - fields: values - .into_iter() - .map(|(key, value)| (key, json_to_protobuf_value(value))) - .collect(), - }), - }), - } -} - -fn struct_to_json_map(config: &prost_types::Struct) -> BTreeMap { - config - .fields - .iter() - .map(|(key, value)| (key.clone(), protobuf_value_to_json(value))) - .collect() -} - -fn protobuf_value_to_json(value: &prost_types::Value) -> serde_json::Value { - match value.kind.as_ref() { - Some(prost_types::value::Kind::NullValue(_)) | None => serde_json::Value::Null, - Some(prost_types::value::Kind::BoolValue(value)) => serde_json::Value::Bool(*value), - Some(prost_types::value::Kind::NumberValue(value)) => serde_json::Number::from_f64(*value) - .map_or(serde_json::Value::Null, serde_json::Value::Number), - Some(prost_types::value::Kind::StringValue(value)) => { - serde_json::Value::String(value.clone()) - } - Some(prost_types::value::Kind::ListValue(value)) => { - serde_json::Value::Array(value.values.iter().map(protobuf_value_to_json).collect()) - } - Some(prost_types::value::Kind::StructValue(value)) => serde_json::Value::Object( - value - .fields - .iter() - .map(|(key, value)| (key.clone(), protobuf_value_to_json(value))) - .collect(), - ), - } -} - // --------------------------------------------------------------------------- // Sandbox UID/GID constants // --------------------------------------------------------------------------- @@ -1086,7 +973,7 @@ pub fn parse_sandbox_policy(yaml: &str) -> Result { let raw: PolicyFile = serde_yml::from_str(yaml) .into_diagnostic() .wrap_err("failed to parse sandbox policy YAML")?; - Ok(to_proto(raw)) + to_proto(raw) } /// Serialize a proto sandbox policy to a YAML string. @@ -1343,45 +1230,6 @@ impl fmt::Display for PolicyViolation { } } -/// Match a middleware host selector pattern using the runtime's glob semantics. -/// -/// Invalid or empty patterns return an error instead of silently becoming a -/// non-match. -pub fn middleware_host_matches(pattern: &str, host: &str) -> std::result::Result { - if pattern.is_empty() { - return Err("host pattern must not be empty".to_string()); - } - if pattern.chars().any(char::is_whitespace) { - return Err("host pattern must not contain whitespace".to_string()); - } - - let pattern = glob::Pattern::new(&pattern.to_ascii_lowercase()) - .map_err(|error| format!("invalid host pattern: {error}"))?; - Ok(pattern.matches(&host.to_ascii_lowercase())) -} - -fn middleware_selector_matches_host( - middleware: &NetworkMiddlewareConfig, - host: &str, -) -> std::result::Result { - let Some(selector) = &middleware.endpoints else { - return Ok(false); - }; - let matches_include = selector - .include - .iter() - .try_fold(false, |matched, pattern| { - middleware_host_matches(pattern, host).map(|matches| matched || matches) - })?; - let matches_exclude = selector - .exclude - .iter() - .try_fold(false, |matched, pattern| { - middleware_host_matches(pattern, host).map(|matches| matched || matches) - })?; - Ok(matches_include && !matches_exclude) -} - /// Validate that a sandbox policy does not contain unsafe content. /// /// Returns `Ok(())` if the policy is safe, or `Err(violations)` listing all @@ -1513,96 +1361,7 @@ pub fn validate_sandbox_policy( } } - let mut middleware_names = HashSet::new(); - for middleware in &policy.network_middlewares { - if middleware.name.is_empty() { - violations.push(PolicyViolation::InvalidMiddlewareConfig { - name: middleware.name.clone(), - reason: "name must not be empty".to_string(), - }); - } else if !middleware_names.insert(middleware.name.clone()) { - violations.push(PolicyViolation::DuplicateMiddlewareConfigName { - name: middleware.name.clone(), - }); - } - - if middleware.middleware.is_empty() { - violations.push(PolicyViolation::InvalidMiddlewareConfig { - name: middleware.name.clone(), - reason: "implementation must not be empty".to_string(), - }); - } else if middleware.middleware.starts_with("openshell/") - && middleware.middleware != openshell_core::middleware::BUILTIN_SECRETS - { - violations.push(PolicyViolation::InvalidMiddlewareConfig { - name: middleware.name.clone(), - reason: format!("unsupported built-in '{}'", middleware.middleware), - }); - } - - if !matches!( - middleware.on_error.as_str(), - "" | "fail_closed" | "fail_open" - ) { - violations.push(PolicyViolation::InvalidMiddlewareConfig { - name: middleware.name.clone(), - reason: format!("invalid on_error '{}'", middleware.on_error), - }); - } - - let Some(selector) = &middleware.endpoints else { - violations.push(PolicyViolation::InvalidMiddlewareConfig { - name: middleware.name.clone(), - reason: "endpoint selector is required".to_string(), - }); - continue; - }; - if selector.include.is_empty() { - violations.push(PolicyViolation::InvalidMiddlewareConfig { - name: middleware.name.clone(), - reason: "endpoint selector must include at least one host pattern".to_string(), - }); - } - for pattern in selector.include.iter().chain(&selector.exclude) { - if let Err(reason) = middleware_host_matches(pattern, "validation.invalid") { - violations.push(PolicyViolation::InvalidMiddlewareConfig { - name: middleware.name.clone(), - reason: format!("endpoint selector pattern '{pattern}' is invalid: {reason}"), - }); - } - } - - if middleware.middleware == openshell_core::middleware::BUILTIN_SECRETS { - let config = middleware.config.clone().unwrap_or_default(); - if let Err(error) = - openshell_core::middleware::validate_builtin_config(&middleware.middleware, &config) - { - violations.push(PolicyViolation::InvalidBuiltinMiddlewareConfig { - name: middleware.name.clone(), - reason: error.to_string(), - }); - } - } - - for (key, rule) in &policy.network_policies { - let policy_name = if rule.name.is_empty() { - key - } else { - &rule.name - }; - for endpoint in &rule.endpoints { - if endpoint.tls == "skip" - && middleware_selector_matches_host(middleware, &endpoint.host).unwrap_or(false) - { - violations.push(PolicyViolation::MiddlewareTlsSkipConflict { - middleware_name: middleware.name.clone(), - policy_name: policy_name.clone(), - host: endpoint.host.clone(), - }); - } - } - } - } + violations.extend(middleware::validate(policy)); if violations.is_empty() { Ok(()) @@ -2019,13 +1778,16 @@ network_policies: // ---- Policy validation tests ---- - fn middleware_config(name: &str, implementation: &str) -> NetworkMiddlewareConfig { - NetworkMiddlewareConfig { + fn middleware_config( + name: &str, + implementation: &str, + ) -> openshell_core::proto::NetworkMiddlewareConfig { + openshell_core::proto::NetworkMiddlewareConfig { name: name.into(), middleware: implementation.into(), config: None, on_error: String::new(), - endpoints: Some(MiddlewareEndpointSelector { + endpoints: Some(openshell_core::proto::MiddlewareEndpointSelector { include: vec!["api.example.com".into()], exclude: Vec::new(), }), diff --git a/crates/openshell-policy/src/middleware.rs b/crates/openshell-policy/src/middleware.rs new file mode 100644 index 000000000..ef10cb70f --- /dev/null +++ b/crates/openshell-policy/src/middleware.rs @@ -0,0 +1,217 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! YAML schema and protobuf conversion for supervisor middleware policies. + +use std::collections::{BTreeMap, HashSet}; + +use openshell_core::proto::{MiddlewareEndpointSelector, NetworkMiddlewareConfig, SandboxPolicy}; +use openshell_core::proto_struct::{ + ProtoStructError, json_object_to_struct, struct_to_json_object, +}; +use serde::{Deserialize, Serialize}; + +use super::PolicyViolation; + +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct NetworkMiddlewareConfigDef { + name: String, + middleware: String, + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + config: BTreeMap, + #[serde(default, skip_serializing_if = "String::is_empty")] + on_error: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + endpoints: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +struct MiddlewareEndpointSelectorDef { + #[serde(default, skip_serializing_if = "Vec::is_empty")] + include: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + exclude: Vec, +} + +pub fn into_proto( + definitions: Vec, +) -> Result, ProtoStructError> { + definitions + .into_iter() + .map(|definition| { + Ok(NetworkMiddlewareConfig { + name: definition.name, + middleware: definition.middleware, + config: Some(json_object_to_struct( + definition.config.into_iter().collect(), + )?), + on_error: definition.on_error, + endpoints: definition + .endpoints + .map(|selector| MiddlewareEndpointSelector { + include: selector.include, + exclude: selector.exclude, + }), + }) + }) + .collect() +} + +pub fn from_proto(middlewares: &[NetworkMiddlewareConfig]) -> Vec { + middlewares + .iter() + .map(|middleware| NetworkMiddlewareConfigDef { + name: middleware.name.clone(), + middleware: middleware.middleware.clone(), + config: middleware + .config + .as_ref() + .map(struct_to_json_object) + .unwrap_or_default() + .into_iter() + .collect(), + on_error: middleware.on_error.clone(), + endpoints: middleware.endpoints.as_ref().map(|selector| { + MiddlewareEndpointSelectorDef { + include: selector.include.clone(), + exclude: selector.exclude.clone(), + } + }), + }) + .collect() +} + +/// Match a middleware host selector pattern using the runtime's glob semantics. +/// +/// Invalid or empty patterns return an error instead of silently becoming a +/// non-match. +pub fn middleware_host_matches(pattern: &str, host: &str) -> Result { + if pattern.is_empty() { + return Err("host pattern must not be empty".to_string()); + } + if pattern.chars().any(char::is_whitespace) { + return Err("host pattern must not contain whitespace".to_string()); + } + + let pattern = glob::Pattern::new(&pattern.to_ascii_lowercase()) + .map_err(|error| format!("invalid host pattern: {error}"))?; + Ok(pattern.matches(&host.to_ascii_lowercase())) +} + +fn selector_matches_host(middleware: &NetworkMiddlewareConfig, host: &str) -> Result { + let Some(selector) = &middleware.endpoints else { + return Ok(false); + }; + let matches_include = selector + .include + .iter() + .try_fold(false, |matched, pattern| { + middleware_host_matches(pattern, host).map(|matches| matched || matches) + })?; + let matches_exclude = selector + .exclude + .iter() + .try_fold(false, |matched, pattern| { + middleware_host_matches(pattern, host).map(|matches| matched || matches) + })?; + Ok(matches_include && !matches_exclude) +} + +pub fn validate(policy: &SandboxPolicy) -> Vec { + let mut violations = Vec::new(); + let mut names = HashSet::new(); + + for middleware in &policy.network_middlewares { + if middleware.name.is_empty() { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "name must not be empty".to_string(), + }); + } else if !names.insert(middleware.name.clone()) { + violations.push(PolicyViolation::DuplicateMiddlewareConfigName { + name: middleware.name.clone(), + }); + } + + if middleware.middleware.is_empty() { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "implementation must not be empty".to_string(), + }); + } else if middleware.middleware.starts_with("openshell/") + && middleware.middleware != openshell_core::middleware::BUILTIN_SECRETS + { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("unsupported built-in '{}'", middleware.middleware), + }); + } + + if !matches!( + middleware.on_error.as_str(), + "" | "fail_closed" | "fail_open" + ) { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("invalid on_error '{}'", middleware.on_error), + }); + } + + let Some(selector) = &middleware.endpoints else { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "endpoint selector is required".to_string(), + }); + continue; + }; + if selector.include.is_empty() { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "endpoint selector must include at least one host pattern".to_string(), + }); + } + for pattern in selector.include.iter().chain(&selector.exclude) { + if let Err(reason) = middleware_host_matches(pattern, "validation.invalid") { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("endpoint selector pattern '{pattern}' is invalid: {reason}"), + }); + } + } + + if middleware.middleware == openshell_core::middleware::BUILTIN_SECRETS { + let config = middleware.config.clone().unwrap_or_default(); + if let Err(error) = + openshell_core::middleware::validate_builtin_config(&middleware.middleware, &config) + { + violations.push(PolicyViolation::InvalidBuiltinMiddlewareConfig { + name: middleware.name.clone(), + reason: error.to_string(), + }); + } + } + + for (key, rule) in &policy.network_policies { + let policy_name = if rule.name.is_empty() { + key + } else { + &rule.name + }; + for endpoint in &rule.endpoints { + if endpoint.tls == "skip" + && selector_matches_host(middleware, &endpoint.host).unwrap_or(false) + { + violations.push(PolicyViolation::MiddlewareTlsSkipConflict { + middleware_name: middleware.name.clone(), + policy_name: policy_name.clone(), + host: endpoint.host.clone(), + }); + } + } + } + } + + violations +} From 3b5875965266961bfc5b06ea9f5c6bc5575a7355 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 3 Jul 2026 13:14:45 -0700 Subject: [PATCH 26/27] refactor(middleware): move host matching to core Signed-off-by: Piotr Mlocek --- Cargo.lock | 3 +- architecture/sandbox.md | 9 +++--- crates/openshell-core/Cargo.toml | 1 + crates/openshell-core/README.md | 8 ++--- crates/openshell-core/src/middleware.rs | 31 +++++++++++++++++++ crates/openshell-policy/Cargo.toml | 2 -- crates/openshell-policy/src/lib.rs | 16 +++++----- crates/openshell-policy/src/middleware.rs | 19 ++---------- .../openshell-supervisor-network/src/opa.rs | 10 +++--- 9 files changed, 56 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3552ea04c..46a081025 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3630,6 +3630,7 @@ version = "0.0.0" dependencies = [ "base64 0.22.1", "chrono", + "glob", "ipnet", "miette", "prost", @@ -3769,10 +3770,8 @@ dependencies = [ name = "openshell-policy" version = "0.0.0" dependencies = [ - "glob", "miette", "openshell-core", - "prost-types", "serde", "serde_json", "serde_yml", diff --git a/architecture/sandbox.md b/architecture/sandbox.md index a9f1208e0..9e984774b 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -73,10 +73,11 @@ operator-registered services are called directly from the supervisor over the common middleware gRPC contract. The gateway validates external service capabilities and policy-owned config before delivery. Supervisors keep the last-known-good service registry when a live config reload fails. Built-in -middleware identifiers and pure config validation live in `openshell-core` so -policy admission does not depend on the supervisor runtime implementation. The -policy and runtime also share the core JSON/protobuf adapter for middleware -configuration, keeping serialization consistent across that boundary. +middleware identifiers, host-selector matching, and pure config validation live +in `openshell-core` so policy admission does not depend on the supervisor +runtime implementation. The policy and runtime also share the core +JSON/protobuf adapter for middleware configuration, keeping serialization +consistent across that boundary. `https://inference.local` is special. It bypasses OPA network policy and is handled by the inference interception path: diff --git a/crates/openshell-core/Cargo.toml b/crates/openshell-core/Cargo.toml index 0ff6d06d6..53735a6bb 100644 --- a/crates/openshell-core/Cargo.toml +++ b/crates/openshell-core/Cargo.toml @@ -11,6 +11,7 @@ license.workspace = true repository.workspace = true [dependencies] +glob = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } tonic = { workspace = true, features = ["channel", "tls-native-roots"] } diff --git a/crates/openshell-core/README.md b/crates/openshell-core/README.md index e27ab167f..b4cdb05f5 100644 --- a/crates/openshell-core/README.md +++ b/crates/openshell-core/README.md @@ -53,10 +53,10 @@ behavior here, then consume it from the gateway, sandbox, and router. ## Middleware Contracts -Built-in supervisor middleware identifiers and pure configuration validation -live in `openshell_core::middleware`. Policy admission and the supervisor -runtime consume the same contract without introducing a dependency from the -policy crate to the supervisor implementation. +Built-in supervisor middleware identifiers, host-selector matching, and pure +configuration validation live in `openshell_core::middleware`. Policy admission +and the supervisor runtime consume the same contract without introducing a +dependency from the policy crate to the supervisor implementation. ## Protobuf Struct Conversion diff --git a/crates/openshell-core/src/middleware.rs b/crates/openshell-core/src/middleware.rs index 6fb15b1c1..1688cd4f9 100644 --- a/crates/openshell-core/src/middleware.rs +++ b/crates/openshell-core/src/middleware.rs @@ -8,6 +8,23 @@ use miette::{Result, miette}; /// Binding identifier for the built-in secret redaction middleware. pub const BUILTIN_SECRETS: &str = "openshell/secrets"; +/// Match a middleware host selector pattern using the runtime's glob semantics. +/// +/// Matching is case-insensitive. Invalid or empty patterns return an error +/// instead of silently becoming a non-match. +pub fn host_matches(pattern: &str, host: &str) -> std::result::Result { + if pattern.is_empty() { + return Err("host pattern must not be empty".to_string()); + } + if pattern.chars().any(char::is_whitespace) { + return Err("host pattern must not contain whitespace".to_string()); + } + + let pattern = glob::Pattern::new(&pattern.to_ascii_lowercase()) + .map_err(|error| format!("invalid host pattern: {error}"))?; + Ok(pattern.matches(&host.to_ascii_lowercase())) +} + /// Validate policy-owned configuration for a built-in middleware. pub fn validate_builtin_config(implementation: &str, config: &prost_types::Struct) -> Result<()> { match implementation { @@ -39,6 +56,20 @@ fn validate_secrets_config(config: &prost_types::Struct) -> Result<()> { mod tests { use super::*; + #[test] + fn host_matching_is_case_insensitive() { + assert!(host_matches("*.Example.COM", "API.example.com").unwrap()); + assert!(!host_matches("*.example.com", "example.com").unwrap()); + assert!(host_matches("*", "deep.api.example.com").unwrap()); + } + + #[test] + fn host_matching_rejects_invalid_patterns() { + assert!(host_matches("", "api.example.com").is_err()); + assert!(host_matches("api .example.com", "api.example.com").is_err()); + assert!(host_matches("api[.example.com", "api.example.com").is_err()); + } + #[test] fn secrets_config_defaults_to_redact() { validate_builtin_config(BUILTIN_SECRETS, &prost_types::Struct::default()).unwrap(); diff --git a/crates/openshell-policy/Cargo.toml b/crates/openshell-policy/Cargo.toml index 036964b72..16719de13 100644 --- a/crates/openshell-policy/Cargo.toml +++ b/crates/openshell-policy/Cargo.toml @@ -11,9 +11,7 @@ license.workspace = true repository.workspace = true [dependencies] -glob = { workspace = true } openshell-core = { path = "../openshell-core", default-features = false } -prost-types = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } serde_yml = { workspace = true } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index d388e28a2..d94fd9002 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -1826,15 +1826,13 @@ network_policies: fn validate_rejects_invalid_builtin_middleware_config() { let mut policy = restrictive_default_policy(); let mut middleware = middleware_config("redact-secrets", "openshell/secrets"); - middleware.config = Some(prost_types::Struct { - fields: std::iter::once(( - "secrets".into(), - prost_types::Value { - kind: Some(prost_types::value::Kind::StringValue("allow".into())), - }, - )) - .collect(), - }); + middleware.config = Some( + openshell_core::proto_struct::json_object_to_struct( + std::iter::once(("secrets".into(), serde_json::Value::String("allow".into()))) + .collect(), + ) + .unwrap(), + ); policy.network_middlewares.push(middleware); let violations = validate_sandbox_policy(&policy).expect_err("invalid config"); diff --git a/crates/openshell-policy/src/middleware.rs b/crates/openshell-policy/src/middleware.rs index ef10cb70f..1e795f600 100644 --- a/crates/openshell-policy/src/middleware.rs +++ b/crates/openshell-policy/src/middleware.rs @@ -13,6 +13,8 @@ use serde::{Deserialize, Serialize}; use super::PolicyViolation; +pub use openshell_core::middleware::host_matches as middleware_host_matches; + #[derive(Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct NetworkMiddlewareConfigDef { @@ -83,23 +85,6 @@ pub fn from_proto(middlewares: &[NetworkMiddlewareConfig]) -> Vec Result { - if pattern.is_empty() { - return Err("host pattern must not be empty".to_string()); - } - if pattern.chars().any(char::is_whitespace) { - return Err("host pattern must not contain whitespace".to_string()); - } - - let pattern = glob::Pattern::new(&pattern.to_ascii_lowercase()) - .map_err(|error| format!("invalid host pattern: {error}"))?; - Ok(pattern.matches(&host.to_ascii_lowercase())) -} - fn selector_matches_host(middleware: &NetworkMiddlewareConfig, host: &str) -> Result { let Some(selector) = &middleware.endpoints else { return Ok(false); diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index 8702c6e6e..c34ea9f61 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -741,14 +741,14 @@ fn middleware_selector_matches(config: ®orus::Value, host: &str) -> Result Vec { } for pattern in includes.iter().chain(&excludes) { if let Err(error) = - openshell_policy::middleware_host_matches(pattern, "validation.invalid") + openshell_core::middleware::host_matches(pattern, "validation.invalid") { errors.push(format!( "middleware config '{name}' has invalid endpoint selector pattern '{pattern}': {error}" @@ -1220,10 +1220,10 @@ fn global_selector_matches_any_middleware( let excludes = json_string_array(selector.get("exclude")); !includes.is_empty() && includes.iter().any(|pattern| { - openshell_policy::middleware_host_matches(pattern, host).unwrap_or(false) + openshell_core::middleware::host_matches(pattern, host).unwrap_or(false) }) && !excludes.iter().any(|pattern| { - openshell_policy::middleware_host_matches(pattern, host).unwrap_or(false) + openshell_core::middleware::host_matches(pattern, host).unwrap_or(false) }) }) } From 3b6fc104af78f8a8d2c92958c8816f0f02f18d06 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 3 Jul 2026 14:31:46 -0700 Subject: [PATCH 27/27] feat(middleware): align operations and ordering Signed-off-by: Piotr Mlocek --- architecture/sandbox.md | 4 +- crates/openshell-policy/src/lib.rs | 3 + crates/openshell-policy/src/middleware.rs | 8 ++ .../src/builtins/secrets.rs | 7 +- .../src/lib.rs | 83 +++++++++++++++---- .../src/l7/relay.rs | 3 + .../src/l7/rest.rs | 1 + .../openshell-supervisor-network/src/opa.rs | 16 +++- docs/extensibility/supervisor-middleware.mdx | 9 +- docs/reference/policy-schema.mdx | 4 +- docs/sandboxes/policies.mdx | 6 +- proto/middleware.proto | 22 +++-- proto/sandbox.proto | 2 + 13 files changed, 135 insertions(+), 33 deletions(-) diff --git a/architecture/sandbox.md b/architecture/sandbox.md index 9e984774b..bc3982343 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -68,7 +68,9 @@ are relayed but are not currently parsed for policy enforcement. For admitted HTTP requests, the proxy can run an ordered supervisor middleware chain before credential injection. Host selectors choose the chain independently -of the network rule that admitted the request. Built-ins run in-process; +of the network rule that admitted the request. Policy entries use integer order +values with stable name tie-breaking, and the gRPC contract represents operations +and phases as enums. Built-ins run in-process; operator-registered services are called directly from the supervisor over the common middleware gRPC contract. The gateway validates external service capabilities and policy-owned config before delivery. Supervisors keep diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index d94fd9002..b83e07d71 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -1493,6 +1493,7 @@ version: 1 network_middlewares: - name: global-redactor middleware: openshell/secrets + order: 20 on_error: fail_open endpoints: include: ["api.example.com", "*.service.test"] @@ -1520,6 +1521,7 @@ network_policies: assert_eq!(proto.network_middlewares.len(), 2); assert_eq!(proto.network_middlewares[0].name, "global-redactor"); assert_eq!(proto.network_middlewares[0].middleware, "openshell/secrets"); + assert_eq!(proto.network_middlewares[0].order, 20); assert_eq!(proto.network_middlewares[0].on_error, "fail_open"); assert_eq!( proto.network_middlewares[0] @@ -1785,6 +1787,7 @@ network_policies: openshell_core::proto::NetworkMiddlewareConfig { name: name.into(), middleware: implementation.into(), + order: 0, config: None, on_error: String::new(), endpoints: Some(openshell_core::proto::MiddlewareEndpointSelector { diff --git a/crates/openshell-policy/src/middleware.rs b/crates/openshell-policy/src/middleware.rs index 1e795f600..71c5ac9a7 100644 --- a/crates/openshell-policy/src/middleware.rs +++ b/crates/openshell-policy/src/middleware.rs @@ -20,6 +20,8 @@ pub use openshell_core::middleware::host_matches as middleware_host_matches; pub struct NetworkMiddlewareConfigDef { name: String, middleware: String, + #[serde(default, skip_serializing_if = "is_default")] + order: i32, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] config: BTreeMap, #[serde(default, skip_serializing_if = "String::is_empty")] @@ -28,6 +30,10 @@ pub struct NetworkMiddlewareConfigDef { endpoints: Option, } +fn is_default(value: &T) -> bool { + value == &T::default() +} + #[derive(Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] struct MiddlewareEndpointSelectorDef { @@ -46,6 +52,7 @@ pub fn into_proto( Ok(NetworkMiddlewareConfig { name: definition.name, middleware: definition.middleware, + order: definition.order, config: Some(json_object_to_struct( definition.config.into_iter().collect(), )?), @@ -67,6 +74,7 @@ pub fn from_proto(middlewares: &[NetworkMiddlewareConfig]) -> Vec MiddlewareBinding { MiddlewareBinding { id: BINDING_ID.into(), - operation: OPERATION.into(), - phase: PHASE.into(), + operation: SupervisorMiddlewareOperation::HttpRequest as i32, + phase: SupervisorMiddlewarePhase::PreCredentials as i32, max_body_bytes: MAX_BODY_BYTES, } } diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index 9b067b31e..5fc15a0dd 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -18,14 +18,16 @@ use openshell_core::proto::middleware::v1::supervisor_middleware_server::Supervi use openshell_core::proto::{ Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, MiddlewareBinding, MiddlewareManifest, NetworkMiddlewareConfig, RequestContext, SandboxPolicy, - SupervisorMiddlewareService, ValidateConfigRequest, + SupervisorMiddlewareOperation, SupervisorMiddlewarePhase, SupervisorMiddlewareService, + ValidateConfigRequest, }; use tokio::sync::OnceCell; use tonic::Request; pub const API_VERSION: &str = "openshell.middleware.v1"; -const HTTP_REQUEST_OPERATION: &str = "HttpRequest"; -const PRE_CREDENTIALS_PHASE: &str = "pre_credentials"; +const HTTP_REQUEST_OPERATION: SupervisorMiddlewareOperation = + SupervisorMiddlewareOperation::HttpRequest; +const PRE_CREDENTIALS_PHASE: SupervisorMiddlewarePhase = SupervisorMiddlewarePhase::PreCredentials; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum OnError { FailClosed, @@ -48,6 +50,7 @@ impl OnError { pub struct ChainEntry { pub name: String, pub implementation: String, + pub order: i32, pub config: prost_types::Struct, pub on_error: OnError, } @@ -68,6 +71,7 @@ impl TryFrom<&NetworkMiddlewareConfig> for ChainEntry { Ok(Self { name: value.name.clone(), implementation: value.middleware.clone(), + order: value.order, config: value.config.clone().unwrap_or_default(), on_error: OnError::parse(&value.on_error)?, }) @@ -297,10 +301,12 @@ fn validate_external_manifest( binding.id )); } - if binding.operation != HTTP_REQUEST_OPERATION || binding.phase != PRE_CREDENTIALS_PHASE { + if binding.operation != HTTP_REQUEST_OPERATION as i32 + || binding.phase != PRE_CREDENTIALS_PHASE as i32 + { return Err(miette!( - "middleware binding '{}' must support {HTTP_REQUEST_OPERATION}/{PRE_CREDENTIALS_PHASE}", - binding.id + "middleware binding '{}' must support HTTP_REQUEST/PRE_CREDENTIALS", + binding.id, )); } let advertised = usize::try_from(binding.max_body_bytes).map_err(|_| { @@ -523,6 +529,8 @@ impl ChainRunner { pub async fn describe_chain(&self, entries: &[ChainEntry]) -> Result> { let manifests = self.manifests().await?; + let mut entries = entries.to_vec(); + sort_chain_entries(&mut entries); entries .iter() .map(|entry| { @@ -810,6 +818,15 @@ impl ChainRunner { } } +/// Sort middleware using the policy-defined priority and a stable name tie-breaker. +pub fn sort_chain_entries(entries: &mut [ChainEntry]) { + entries.sort_by(|left, right| { + left.order + .cmp(&right.order) + .then_with(|| left.name.cmp(&right.name)) + }); +} + fn build_evaluation( entry: &DescribedChainEntry, binding: &MiddlewareBinding, @@ -820,7 +837,7 @@ fn build_evaluation( HttpRequestEvaluation { api_version: API_VERSION.into(), binding_id: binding.id.clone(), - phase: binding.phase.clone(), + phase: binding.phase, context: Some(RequestContext { request_id: input.request_id.clone(), sandbox_id: input.sandbox_id.clone(), @@ -912,6 +929,7 @@ mod tests { ChainEntry { name: name.into(), implementation: BUILTIN_SECRETS.into(), + order: 0, config: prost_types::Struct { fields: std::iter::once(( "secrets".into(), @@ -951,6 +969,10 @@ mod tests { let input = input("payload"); let evaluation = build_evaluation(entry, binding, &input, &BTreeMap::new(), b"payload"); + assert_eq!( + evaluation.phase, + SupervisorMiddlewarePhase::PreCredentials as i32 + ); assert!( evaluation .context @@ -995,11 +1017,32 @@ mod tests { assert_eq!(outcome.applied.len(), 2); } + #[tokio::test] + async fn describe_chain_sorts_by_order_then_name() { + let mut later = entry("later", OnError::FailClosed); + later.order = 20; + let mut beta = entry("beta", OnError::FailClosed); + beta.order = 10; + let mut alpha = entry("alpha", OnError::FailClosed); + alpha.order = 10; + + let described = ChainRunner::default() + .describe_chain(&[later, beta, alpha]) + .await + .expect("describe ordered chain"); + let names: Vec<_> = described + .iter() + .map(|entry| entry.entry.name.as_str()) + .collect(); + assert_eq!(names, vec!["alpha", "beta", "later"]); + } + #[tokio::test] async fn fail_open_allows_unavailable_middleware() { let unavailable = ChainEntry { name: "missing".into(), implementation: "third-party/missing".into(), + order: 0, config: prost_types::Struct::default(), on_error: OnError::FailOpen, }; @@ -1016,6 +1059,7 @@ mod tests { let unavailable = ChainEntry { name: "missing".into(), implementation: "third-party/missing".into(), + order: 0, config: prost_types::Struct::default(), on_error: OnError::FailClosed, }; @@ -1036,8 +1080,14 @@ mod tests { .into_inner(); assert_eq!(manifest.api_version, API_VERSION); assert_eq!(manifest.bindings[0].id, BUILTIN_SECRETS); - assert_eq!(manifest.bindings[0].operation, "HttpRequest"); - assert_eq!(manifest.bindings[0].phase, "pre_credentials"); + assert_eq!( + manifest.bindings[0].operation, + SupervisorMiddlewareOperation::HttpRequest as i32 + ); + assert_eq!( + manifest.bindings[0].phase, + SupervisorMiddlewarePhase::PreCredentials as i32 + ); assert_eq!(manifest.bindings[0].max_body_bytes, 256 * 1024); } @@ -1088,8 +1138,8 @@ mod tests { service_version: "test".into(), bindings: vec![MiddlewareBinding { id: self.binding_id.clone(), - operation: "HttpRequest".into(), - phase: "pre_credentials".into(), + operation: SupervisorMiddlewareOperation::HttpRequest as i32, + phase: SupervisorMiddlewarePhase::PreCredentials as i32, max_body_bytes: self.max_body_bytes, }], })) @@ -1190,6 +1240,7 @@ mod tests { let unresolved = ChainEntry { name: "missing".into(), implementation: "third-party/missing".into(), + order: 10, config: prost_types::Struct::default(), on_error: OnError::FailOpen, }; @@ -1214,6 +1265,7 @@ mod tests { let entry = ChainEntry { name: "external".into(), implementation: "example/redactor".into(), + order: 0, config: prost_types::Struct::default(), on_error: OnError::FailClosed, }; @@ -1229,7 +1281,7 @@ mod tests { .as_ref() .expect("described binding") .phase, - "pre_credentials" + SupervisorMiddlewarePhase::PreCredentials as i32 ); let outcome = runner @@ -1251,6 +1303,7 @@ mod tests { let external_entry = ChainEntry { name: "external".into(), implementation: "example/content-guard".into(), + order: 0, config: prost_types::Struct::default(), on_error: OnError::FailClosed, }; @@ -1284,8 +1337,8 @@ mod tests { service_version: "test".into(), bindings: vec![MiddlewareBinding { id: "example/content-guard".into(), - operation: HTTP_REQUEST_OPERATION.into(), - phase: PRE_CREDENTIALS_PHASE.into(), + operation: HTTP_REQUEST_OPERATION as i32, + phase: PRE_CREDENTIALS_PHASE as i32, max_body_bytes: 4096, }], }; @@ -1346,6 +1399,7 @@ mod tests { network_middlewares: vec![NetworkMiddlewareConfig { name: "guard".into(), middleware: "example/content-guard".into(), + order: 0, config: Some(prost_types::Struct::default()), on_error: "fail_closed".into(), endpoints: None, @@ -1367,6 +1421,7 @@ mod tests { &[ChainEntry { name: "guard".into(), implementation: "example/content-guard".into(), + order: 0, config: prost_types::Struct::default(), on_error: OnError::FailClosed, }], diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 87ab2f39f..96e05f272 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -2801,6 +2801,7 @@ network_policies: let fail_open = ChainEntry { name: "m".into(), implementation: "openshell/secrets".into(), + order: 0, config: prost_types::Struct::default(), on_error: OnError::FailOpen, }; @@ -2845,12 +2846,14 @@ network_policies: let resolved = ChainEntry { name: "redact".into(), implementation: openshell_supervisor_middleware::BUILTIN_SECRETS.into(), + order: 0, config: prost_types::Struct::default(), on_error: OnError::FailClosed, }; let unresolved = ChainEntry { name: "missing".into(), implementation: "third-party/missing".into(), + order: 0, config: prost_types::Struct::default(), on_error: OnError::FailOpen, }; diff --git a/crates/openshell-supervisor-network/src/l7/rest.rs b/crates/openshell-supervisor-network/src/l7/rest.rs index 4f2d37f08..5f276a1d4 100644 --- a/crates/openshell-supervisor-network/src/l7/rest.rs +++ b/crates/openshell-supervisor-network/src/l7/rest.rs @@ -33,6 +33,7 @@ async fn max_middleware_body_bytes() -> usize { .describe_chain(&[openshell_supervisor_middleware::ChainEntry { name: "test".into(), implementation: openshell_supervisor_middleware::BUILTIN_SECRETS.into(), + order: 0, config: prost_types::Struct::default(), on_error: openshell_supervisor_middleware::OnError::FailClosed, }]) diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index c34ea9f61..6ee3b1daf 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -729,6 +729,7 @@ fn global_middleware_entries(configs: &[regorus::Value], host: &str) -> Result Result { Ok(ChainEntry { name, implementation, + order: get_field(value, "order") + .and_then(|value| match value { + regorus::Value::Number(number) => number.as_i64(), + _ => None, + }) + .and_then(|value| i32::try_from(value).ok()) + .unwrap_or_default(), config: get_field(value, "config") .map(regorus_value_to_struct) .unwrap_or_default(), @@ -1630,6 +1638,7 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St let mut value = serde_json::json!({ "name": mw.name, "middleware": mw.middleware, + "order": mw.order, }); if let Some(config) = &mw.config { value["config"] = openshell_core::proto_struct::struct_to_json_value(config); @@ -6732,19 +6741,22 @@ network_policies: } #[test] - fn middleware_chain_uses_matching_selector_declaration_order() { + fn middleware_chain_uses_configured_order_and_name_tie_breaker() { let data = r#" network_middlewares: - name: global-redactor middleware: openshell/secrets + order: 20 endpoints: include: ["api.example.com"] - name: policy-redactor middleware: openshell/secrets + order: 10 endpoints: include: ["api.example.com"] - name: endpoint-redactor middleware: openshell/secrets + order: 10 endpoints: include: ["api.example.com"] network_policies: @@ -6775,7 +6787,7 @@ network_policies: let names: Vec<_> = chain.iter().map(|entry| entry.name.as_str()).collect(); assert_eq!( names, - vec!["global-redactor", "policy-redactor", "endpoint-redactor"] + vec!["endpoint-redactor", "policy-redactor", "global-redactor"] ); } diff --git a/docs/extensibility/supervisor-middleware.mdx b/docs/extensibility/supervisor-middleware.mdx index f2ea0c98a..80b7b9d43 100644 --- a/docs/extensibility/supervisor-middleware.mdx +++ b/docs/extensibility/supervisor-middleware.mdx @@ -18,7 +18,7 @@ For each inspected HTTP request, the supervisor: 1. Evaluates network and L7 policy. 2. Selects middleware whose host selectors match the admitted destination. 3. Buffers the request body using the smallest body limit in the selected chain. -4. Runs matching middleware in policy declaration order. +4. Runs matching middleware by ascending `order`, using the policy-local name to break ties. 5. Applies allowed transformations, injects provider credentials, and forwards the request. Middleware receives the request before credential injection. Operator-run services cannot inspect OpenShell-managed credentials. @@ -63,6 +63,7 @@ Add middleware configs to the top-level `network_middlewares` list: network_middlewares: - name: redact-secrets middleware: openshell/secrets + order: 10 config: secrets: redact on_error: fail_closed @@ -71,11 +72,11 @@ network_middlewares: exclude: ["trusted.example.com"] ``` -Each config has a policy-local `name`, a built-in or operator-provided binding ID in `middleware`, implementation-owned `config`, failure behavior, and host selectors. +Each config has a policy-local `name`, a built-in or operator-provided binding ID in `middleware`, an integer `order`, implementation-owned `config`, failure behavior, and host selectors. `include` selects destination hosts. `exclude` takes precedence and removes hosts from that selection. Matching is case-insensitive and uses the same exact-host and DNS glob behavior as network policy endpoints. -Matching configs run once each in top-level declaration order. Different config names may reference the same binding and run as separate stages. Config names must be unique. +Matching configs run once each by ascending `order`; lower values run first and policy-local names break ties. The default order is `0`. Different config names may reference the same binding and run as separate stages. Config names must be unique. See [Policy Schema](/reference/policy-schema#network-middleware) for the complete field reference. @@ -131,7 +132,7 @@ See [Logging](/observability/logging) for log access and [OCSF JSON Export](/obs ## Current Limitations - Middleware applies only to HTTP requests parsed by the supervisor. -- The supported operation and phase are `HttpRequest/pre_credentials`. +- The typed operation and phase are `HTTP_REQUEST/PRE_CREDENTIALS`. - Selection uses destination host include and exclude patterns. - Required middleware cannot cover `tls: skip` endpoints because OpenShell cannot inspect that traffic. - Operator-run services support plaintext `http://` and TLS `https://` endpoints. HTTPS certificates must chain to a CA in the platform trust store. diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index e36c535be..266f5d8ca 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -478,12 +478,13 @@ Identifies an executable that is permitted to use the associated endpoints. **Category:** Dynamic -An ordered list of middleware configs selected after network and L7 policy admit an HTTP request. Middleware selection is independent of the network policy entry that admitted the request. Every matching config runs once in list order before provider credential injection. +An ordered list of middleware configs selected after network and L7 policy admit an HTTP request. Middleware selection is independent of the network policy entry that admitted the request. Every matching config runs once by ascending `order`, with the policy-local name breaking ties, before provider credential injection. ```yaml showLineNumbers={false} network_middlewares: - name: redact-secrets middleware: openshell/secrets + order: 10 config: secrets: redact on_error: fail_closed @@ -496,6 +497,7 @@ network_middlewares: |---|---|---|---| | `name` | string | Yes | Policy-local config name. Names must be unique within the list. | | `middleware` | string | Yes | Built-in or operator-registered binding ID. `openshell/` is reserved for built-ins. | +| `order` | integer | No | Execution priority. Lower values run first; names break ties. Defaults to `0`. | | `config` | object | No | Implementation-owned configuration validated by the selected middleware. | | `on_error` | string | No | `fail_closed` denies the request when the stage fails; `fail_open` skips the failed stage. Defaults to `fail_closed`. | | `endpoints` | object | Yes | Host selector with required non-empty `include` and optional `exclude` lists. Exclusions take precedence. | diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 1353fd640..a49e89c19 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -48,6 +48,7 @@ network_policies: network_middlewares: - name: redact-secrets middleware: openshell/secrets + order: 10 config: secrets: redact on_error: fail_closed @@ -68,7 +69,7 @@ Raw streams are connection-scoped and outside L7 live-reload guarantees. This in | `landlock` | Static | Configures Landlock LSM enforcement behavior. Set `compatibility` to `best_effort` (skip individual inaccessible paths while applying remaining rules) or `hard_requirement` (fail if any path is inaccessible or the required kernel ABI is unavailable). Refer to the [Policy Schema Reference](/reference/policy-schema#landlock) for the full behavior table. | | `process` | Static | Sets the OS-level identity for the agent process. `run_as_user` and `run_as_group` default to `sandbox`. Root (`root` or `0`) is rejected. The agent also runs with seccomp filters that block dangerous system calls. | | `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path). For endpoints with `protocol: websocket`, the proxy validates the RFC 6455 upgrade and evaluates `GET` rules for the handshake plus either `WEBSOCKET_TEXT` rules for raw client text messages or GraphQL operation rules for GraphQL-over-WebSocket messages. Set `websocket_credential_rewrite: true` only when a WebSocket or REST compatibility endpoint must keep placeholder credentials in sandbox-owned text frames and resolve them at the OpenShell relay boundary.
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | -| `network_middlewares` | Dynamic | Declares ordered HTTP request middleware configs. After network and L7 policy admit a request, OpenShell matches each config's host selectors independently and runs matching entries in declaration order before credential injection. | +| `network_middlewares` | Dynamic | Declares ordered HTTP request middleware configs. After network and L7 policy admit a request, OpenShell matches each config's host selectors independently and runs matching entries by ascending `order`, using the policy-local name to break ties, before credential injection. | ## Supervisor Middleware @@ -78,6 +79,7 @@ Supervisor middleware can inspect, deny, or replace admitted HTTP request bodies network_middlewares: - name: redact-secrets middleware: openshell/secrets + order: 10 config: secrets: redact on_error: fail_closed @@ -86,7 +88,7 @@ network_middlewares: exclude: ["trusted.example.com"] ``` -Matching entries run once each in top-level declaration order. Config names must be unique. Different config names may use the same implementation and run as distinct stages. `exclude` takes precedence over `include`. +Matching entries run once each by ascending `order`; lower values run first and policy-local names break ties. The default order is `0`. Config names must be unique. Different config names may use the same implementation and run as distinct stages. `exclude` takes precedence over `include`. `openshell/secrets` is built into the supervisor. Operator-provided binding IDs must be registered before a policy can reference them. The gateway validates implementation-owned config before accepting the policy. diff --git a/proto/middleware.proto b/proto/middleware.proto index 7a4eb28df..421acd3d5 100644 --- a/proto/middleware.proto +++ b/proto/middleware.proto @@ -38,10 +38,10 @@ message MiddlewareManifest { message MiddlewareBinding { // Stable binding id used by policy configuration and audit logs. string id = 1; - // Supported operation name. V1 supports "HttpRequest". - string operation = 2; - // Supported evaluation phase. V1 supports "pre_credentials". - string phase = 3; + // Supported operation. V1 supports HTTP_REQUEST. + SupervisorMiddlewareOperation operation = 2; + // Supported evaluation phase. V1 supports PRE_CREDENTIALS. + SupervisorMiddlewarePhase phase = 3; // Maximum request or replacement body this binding can process. uint64 max_body_bytes = 4; } @@ -71,7 +71,7 @@ message HttpRequestEvaluation { // Manifest binding id selected for this evaluation. string binding_id = 2; // Evaluation phase selected for this request. - string phase = 3; + SupervisorMiddlewarePhase phase = 3; // Sandbox and request identity available to the supervisor. RequestContext context = 4; // Validated service-specific policy configuration. @@ -84,6 +84,18 @@ message HttpRequestEvaluation { bytes body = 8; } +// Supervisor operation selected for middleware evaluation. +enum SupervisorMiddlewareOperation { + SUPERVISOR_MIDDLEWARE_OPERATION_UNSPECIFIED = 0; + SUPERVISOR_MIDDLEWARE_OPERATION_HTTP_REQUEST = 1; +} + +// Ordered phase within a supervisor operation. +enum SupervisorMiddlewarePhase { + SUPERVISOR_MIDDLEWARE_PHASE_UNSPECIFIED = 0; + SUPERVISOR_MIDDLEWARE_PHASE_PRE_CREDENTIALS = 1; +} + // RequestContext identifies the sandbox request being evaluated. message RequestContext { // Request id used to correlate middleware and supervisor logs. diff --git a/proto/sandbox.proto b/proto/sandbox.proto index c4018573d..a8a27349c 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -77,6 +77,8 @@ message NetworkMiddlewareConfig { string on_error = 4; // Host selector controlling which admitted destinations use this config. MiddlewareEndpointSelector endpoints = 5; + // Deterministic execution order. Lower values run first; names break ties. + int32 order = 6; } // Host selector controlling which admitted destinations use a middleware config.