diff --git a/src/attested_get.rs b/src/attested_get.rs index 7fc40b9..a2a0ad8 100644 --- a/src/attested_get.rs +++ b/src/attested_get.rs @@ -86,6 +86,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), diff --git a/src/file_server.rs b/src/file_server.rs index 4c5c9bb..abfef0b 100644 --- a/src/file_server.rs +++ b/src/file_server.rs @@ -7,16 +7,42 @@ use std::{net::SocketAddr, path::PathBuf}; use tokio::net::ToSocketAddrs; use tower_http::services::ServeDir; +/// Configuration for serving a local directory over the attested proxy +pub struct AttestedFileServerConfig { + /// Filesystem path to expose over HTTP + pub path_to_serve: PathBuf, + /// TLS certificate and key for the optional outer listener + pub outer_cert_and_key: Option, + /// Bind address for the optional outer nested-TLS listener + pub outer_listen_addr: Option, + /// Bind address for the optional inner attested-TLS listener + pub inner_listen_addr: Option, + /// Certificate name to embed in the inner attested certificate + pub inner_certificate_name: Option, + /// Attestation generator used by the proxy server + pub attestation_generator: AttestationGenerator, + /// Attestation verifier used for the remote peer + pub attestation_verifier: AttestationVerifier, + /// Whether inner TLS should require client authentication + pub client_auth: bool, +} + /// Setup a static file server serving the given directory, and a proxy server targetting it -pub async fn attested_file_server( - path_to_serve: PathBuf, - outer_cert_and_key: Option, - outer_listen_addr: Option, - inner_listen_addr: Option, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, - client_auth: bool, -) -> Result<(), ProxyError> { +pub async fn attested_file_server(config: AttestedFileServerConfig) -> Result<(), ProxyError> +where + A: ToSocketAddrs, +{ + let AttestedFileServerConfig { + path_to_serve, + outer_cert_and_key, + outer_listen_addr, + inner_listen_addr, + inner_certificate_name, + attestation_generator, + attestation_verifier, + client_auth, + } = config; + let target_addr = static_file_server(path_to_serve).await?; let outer_session = match (outer_cert_and_key, outer_listen_addr) { (Some(cert_and_key), Some(listen_addr)) => Some(OuterTlsConfig { @@ -32,6 +58,7 @@ pub async fn attested_file_server( let server = ProxyServer::new( outer_session, inner_listen_addr, + inner_certificate_name, target_addr.to_string(), attestation_generator, attestation_verifier, @@ -121,6 +148,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), diff --git a/src/http_version.rs b/src/http_version.rs index 901df66..157a6d9 100644 --- a/src/http_version.rs +++ b/src/http_version.rs @@ -9,6 +9,7 @@ pub const ALPN_HTTP11: &[u8] = b"http/1.1"; type ProxyClientTlsStream = tokio_rustls::client::TlsStream>; +type ProxyClientInnerOnlyTlsStream = tokio_rustls::client::TlsStream; /// Supported HTTP versions #[derive(Debug)] @@ -60,12 +61,21 @@ type Http2Sender = hyper::client::conn::http2::SendRequest, hyper::body::Incoming>; +type Http1InnerOnlyConnection = hyper::client::conn::http1::Connection< + TokioIo, + hyper::body::Incoming, +>; type Http2Connection = hyper::client::conn::http2::Connection< TokioIo, hyper::body::Incoming, crate::TokioExecutor, >; +type Http2InnerOnlyConnection = hyper::client::conn::http2::Connection< + TokioIo, + hyper::body::Incoming, + crate::TokioExecutor, +>; /// A protocol version agnostic HTTP sender pub enum HttpSender { @@ -102,7 +112,9 @@ pin_project_lite::pin_project! { #[project = HttpConnectionProj] pub enum HttpConnection { Http1 { #[pin] inner: Http1Connection }, + Http1InnerOnly { #[pin] inner: Http1InnerOnlyConnection }, Http2 { #[pin] inner: Http2Connection }, + Http2InnerOnly { #[pin] inner: Http2InnerOnlyConnection }, } } @@ -118,13 +130,27 @@ impl From for HttpConnection { } } +impl From for HttpConnection { + fn from(inner: Http1InnerOnlyConnection) -> Self { + Self::Http1InnerOnly { inner } + } +} + +impl From for HttpConnection { + fn from(inner: Http2InnerOnlyConnection) -> Self { + Self::Http2InnerOnly { inner } + } +} + impl Future for HttpConnection { type Output = Result<(), hyper::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.project() { HttpConnectionProj::Http1 { inner } => inner.poll(cx), + HttpConnectionProj::Http1InnerOnly { inner } => inner.poll(cx), HttpConnectionProj::Http2 { inner } => inner.poll(cx), + HttpConnectionProj::Http2InnerOnly { inner } => inner.poll(cx), } } } diff --git a/src/lib.rs b/src/lib.rs index dd9c17f..2a11b8f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,19 +19,20 @@ use http::{HeaderMap, HeaderName, HeaderValue}; use http_body_util::{BodyExt, combinators::BoxBody}; use hyper::{Response, service::service_fn}; use hyper_util::rt::TokioIo; -use nested_tls::server::NestingTlsStream; -use nested_tls::{client::NestingTlsConnector, server::NestingTlsAcceptor}; +use nested_tls::{ + client::NestingTlsConnector, server::NestingTlsAcceptor, server::NestingTlsStream, +}; use std::{net::SocketAddr, num::TryFromIntError, sync::Arc, time::Duration}; use thiserror::Error; use tokio::io::{self, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; -use tokio_rustls::TlsAcceptor; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; use tokio_rustls::rustls::{ self, ClientConfig, RootCertStore, ServerConfig, pki_types::{CertificateDer, PrivateKeyDer, ServerName}, }; +use tokio_rustls::{TlsAcceptor, TlsConnector}; use tracing::{debug, error, warn}; use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion}; @@ -61,6 +62,14 @@ type RequestWithResponseSender = ( type OuterProxySession = (Arc, NestingTlsAcceptor); type InnerProxySession = (Arc, TlsAcceptor); +#[derive(Clone)] +enum ProxyTlsConnector { + Nested(NestingTlsConnector), + InnerOnly(TlsConnector), +} + +impl ProxyTlsConnector {} + /// TLS Credentials pub struct TlsCertAndKey { /// Der-encoded TLS certificate chain @@ -227,6 +236,7 @@ impl ProxyServer { pub async fn new( outer_session: Option>, inner_local: Option, + inner_certificate_name: Option, target: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, @@ -243,7 +253,8 @@ impl ProxyServer { let certificate_name = outer_session .as_ref() .map(OuterTlsConfig::certificate_name) - .transpose()?; + .transpose()? + .or(inner_certificate_name); let inner_server_config = Arc::new( build_inner_server_config( attestation_generator, @@ -709,13 +720,76 @@ impl ProxyClient { let nesting_tls_connector = NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); - Self::new_with_inner(address, nesting_tls_connector, &target_name).await + Self::new_with_connector( + address, + ProxyTlsConnector::Nested(nesting_tls_connector), + &target_name, + ) + .await + } + + /// Start a proxy client which connects directly to the server's inner attested TLS listener. + pub async fn new_inner_only( + cert_and_key: Option, + address: impl ToSocketAddrs, + server_name: String, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + ) -> Result { + if cert_and_key.is_some() { + return Err(ProxyError::InnerOnlyClientAuthUnsupported); + } + + Self::new_inner_only_with_tls_config( + address, + server_name, + attestation_generator, + attestation_verifier, + None, + ) + .await + } + + /// Create a new inner-only proxy client with given TLS configuration. + pub async fn new_inner_only_with_tls_config( + address: impl ToSocketAddrs, + target_name: String, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + cert_chain: Option>>, + ) -> Result { + let attested_cert_verifier = AttestedCertificateVerifier::new(None, attestation_verifier)?; + + let mut inner_client_config = if let Some(cert_chain) = cert_chain.as_ref() { + let inner_cert_resolver = build_attested_cert_resolver( + attestation_generator, + certificate_identity_from_chain(cert_chain)?, + ) + .await?; + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_client_cert_resolver(Arc::new(inner_cert_resolver)) + } else { + ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .dangerous() + .with_custom_certificate_verifier(Arc::new(attested_cert_verifier)) + .with_no_client_auth() + }; + ensure_proxy_alpn_protocols(&mut inner_client_config.alpn_protocols); + + Self::new_with_connector( + address, + ProxyTlsConnector::InnerOnly(TlsConnector::from(Arc::new(inner_client_config))), + &target_name, + ) + .await } - /// Create a new proxy client with given [AttestedTlsClient] - pub async fn new_with_inner( + /// Create a new proxy client with a configured TLS connector. + async fn new_with_connector( address: impl ToSocketAddrs, - nesting_tls_connector: NestingTlsConnector, + tls_connector: ProxyTlsConnector, target_name: &str, ) -> Result { let listener = TcpListener::bind(address).await?; @@ -740,7 +814,7 @@ impl ProxyClient { 'reconnect: loop { let (mut sender, conn, attestation) = // Connect to the proxy server and provide / verify attestation - match Self::setup_connection_with_backoff(&target, &nesting_tls_connector, first) + match Self::setup_connection_with_backoff(&target, &tls_connector, first) .await { Ok(output) => { @@ -910,14 +984,14 @@ impl ProxyClient { // If it fails retry with a backoff (indefinately) async fn setup_connection_with_backoff( target: &str, - nesting_tls_connector: &NestingTlsConnector, + tls_connector: &ProxyTlsConnector, should_bail: bool, ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); loop { - match Self::setup_connection(nesting_tls_connector, target).await { + match Self::setup_connection(tls_connector, target).await { Ok(output) => { return Ok(output); } @@ -939,54 +1013,88 @@ impl ProxyClient { /// Connect to the proxy-server, do TLS handshake and remote attestation async fn setup_connection( - nesting_tls_connector: &NestingTlsConnector, + tls_connector: &ProxyTlsConnector, target: &str, ) -> Result<(HttpSender, HttpConnection, AttestationExchangeMessage), ProxyError> { let outbound_stream = tokio::net::TcpStream::connect(target).await?; let domain = server_name_from_host(target)?; - let tls_stream = nesting_tls_connector - .connect(domain, outbound_stream) - .await?; - - debug!("[proxy-client] Connected to proxy server"); - - let attestation = { - let (_io, server_connection) = tls_stream.get_ref(); - - let remote_cert_chain = server_connection - .peer_certificates() - .ok_or(ProxyError::NoCertificate)?; - - AttestedCertificateVerifier::extract_custom_attestation_from_cert( - remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, - )? - }; - - // The attestation exchange is now complete - setup an HTTP client - let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); + match tls_connector { + ProxyTlsConnector::Nested(connector) => { + let tls_stream = connector.connect(domain, outbound_stream).await?; + debug!("[proxy-client] Connected to proxy server"); + + let attestation = Self::extract_peer_attestation(&tls_stream)?; + let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); + + let outbound_io = TokioIo::new(tls_stream); + let (sender, conn) = match http_version { + HttpVersion::Http2 => { + let (sender, conn) = + hyper::client::conn::http2::Builder::new(TokioExecutor) + .timer(hyper_util::rt::tokio::TokioTimer::new()) + .keep_alive_interval(Some(Duration::from_secs(KEEP_ALIVE_INTERVAL))) + .keep_alive_timeout(Duration::from_secs(KEEP_ALIVE_TIMEOUT)) + .keep_alive_while_idle(true) + .handshake::<_, hyper::body::Incoming>(outbound_io) + .await?; + (sender.into(), conn.into()) + } + HttpVersion::Http1 => { + let (sender, conn) = hyper::client::conn::http1::Builder::new() + .handshake::<_, hyper::body::Incoming>(outbound_io) + .await?; + (sender.into(), conn.into()) + } + }; - let outbound_io = TokioIo::new(tls_stream); - let (sender, conn) = match http_version { - HttpVersion::Http2 => { - let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) - .timer(hyper_util::rt::tokio::TokioTimer::new()) - .keep_alive_interval(Some(Duration::from_secs(KEEP_ALIVE_INTERVAL))) - .keep_alive_timeout(Duration::from_secs(KEEP_ALIVE_TIMEOUT)) - .keep_alive_while_idle(true) - .handshake::<_, hyper::body::Incoming>(outbound_io) - .await?; - (sender.into(), conn.into()) + Ok((sender, conn, attestation)) } - HttpVersion::Http1 => { - let (sender, conn) = hyper::client::conn::http1::Builder::new() - .handshake::<_, hyper::body::Incoming>(outbound_io) - .await?; - (sender.into(), conn.into()) + ProxyTlsConnector::InnerOnly(connector) => { + let tls_stream = connector.connect(domain, outbound_stream).await?; + debug!("[proxy-client] Connected to proxy server"); + + let attestation = Self::extract_peer_attestation(&tls_stream)?; + let http_version = HttpVersion::from_negotiated_protocol_client(&tls_stream); + + let outbound_io = TokioIo::new(tls_stream); + let (sender, conn) = match http_version { + HttpVersion::Http2 => { + let (sender, conn) = + hyper::client::conn::http2::Builder::new(TokioExecutor) + .timer(hyper_util::rt::tokio::TokioTimer::new()) + .keep_alive_interval(Some(Duration::from_secs(KEEP_ALIVE_INTERVAL))) + .keep_alive_timeout(Duration::from_secs(KEEP_ALIVE_TIMEOUT)) + .keep_alive_while_idle(true) + .handshake::<_, hyper::body::Incoming>(outbound_io) + .await?; + (sender.into(), conn.into()) + } + HttpVersion::Http1 => { + let (sender, conn) = hyper::client::conn::http1::Builder::new() + .handshake::<_, hyper::body::Incoming>(outbound_io) + .await?; + (sender.into(), conn.into()) + } + }; + + Ok((sender, conn, attestation)) } - }; + } + } + + fn extract_peer_attestation( + tls_stream: &tokio_rustls::client::TlsStream, + ) -> Result { + let (_io, server_connection) = tls_stream.get_ref(); + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(ProxyError::NoCertificate)?; - Ok((sender, conn, attestation)) + AttestedCertificateVerifier::extract_custom_attestation_from_cert( + remote_cert_chain.first().ok_or(ProxyError::NoCertificate)?, + ) + .map_err(ProxyError::from) } // Handle a request from the source client to the proxy server @@ -1073,6 +1181,8 @@ pub enum ProxyError { MpscSend, #[error("Client auth must be configured on both the inner and outer TLS sessions")] ClientAuthMisconfigured, + #[error("Inner-session-only mode does not support user-supplied TLS client certificates")] + InnerOnlyClientAuthUnsupported, #[error("At least one server listener must be configured")] NoListenersConfigured, } @@ -1306,6 +1416,7 @@ mod tests { let result = ProxyServer::new( None::>, None::<&str>, + None, "127.0.0.1:1".to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), @@ -1332,6 +1443,7 @@ mod tests { tls: OuterTlsMode::CertAndKey(tls_cert_and_key), }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), @@ -1348,6 +1460,7 @@ mod tests { let inner_only_server = ProxyServer::new( None::>, Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), @@ -1369,6 +1482,7 @@ mod tests { let proxy_server = ProxyServer::new( None::>, Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -1408,6 +1522,184 @@ mod tests { tls_stream.shutdown().await.unwrap(); } + #[tokio::test(flavor = "multi_thread")] + async fn inner_only_client_with_server_attestation() { + let target_addr = example_http_service().await; + + let proxy_server = ProxyServer::new( + None::>, + Some("127.0.0.1:0"), + None, + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.inner_local_addr().unwrap().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_inner_only_with_tls_config( + "127.0.0.1:0", + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + + assert_attestation_type_header(res.headers(), "dcap-tdx"); + assert_mock_measurements_header(res.headers()); + assert_eq!(res.text().await.unwrap(), "No measurements"); + } + + #[tokio::test(flavor = "multi_thread")] + async fn inner_only_client_rejects_user_supplied_tls_client_cert() { + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let err = ProxyClient::new_inner_only( + Some(TlsCertAndKey { + cert_chain, + key: private_key, + }), + "127.0.0.1:0", + "localhost:443".to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap_err() + .to_string(); + + assert!(err.contains("Inner-session-only mode")); + } + + #[tokio::test(flavor = "multi_thread")] + async fn inner_only_client_supports_mutual_attestation() { + let target_addr = example_http_service().await; + let (client_cert_chain, _client_private_key) = + generate_certificate_chain_for_host("localhost"); + + let proxy_server = ProxyServer::new( + None::>, + Some("127.0.0.1:0"), + None, + target_addr.to_string(), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + true, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.inner_local_addr().unwrap().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_inner_only_with_tls_config( + "127.0.0.1:0", + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + Some(client_cert_chain), + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::get(format!("http://{}", proxy_client_addr)) + .await + .unwrap(); + + assert_attestation_type_header(res.headers(), "none"); + assert_no_measurements_header(res.headers()); + assert_mock_measurements(&res.text().await.unwrap()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn inner_only_server_uses_configured_certificate_name() { + use tokio_rustls::rustls::client::ResolvesClientCert; + + let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); + + let resolver = build_attested_cert_resolver( + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + "custom.inner.name".to_string(), + ) + .await + .unwrap(); + + let certified_key = resolver.resolve(&[], &[]).unwrap(); + let cert_chain = &certified_key.cert; + + assert_eq!( + hostname_from_cert(cert_chain.first().unwrap()).unwrap(), + "custom.inner.name" + ); + } + + #[tokio::test(flavor = "multi_thread")] + async fn nested_client_fails_against_inner_only_listener() { + let target_addr = example_http_service().await; + + let proxy_server = ProxyServer::new( + None::>, + Some("127.0.0.1:0"), + None, + target_addr.to_string(), + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + AttestationVerifier::expect_none(), + false, + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.inner_local_addr().unwrap().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let (cert_chain, private_key) = generate_certificate_chain_for_host("localhost"); + let (_server_config, client_config) = generate_tls_config(cert_chain, private_key); + + let err = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0", + format!("localhost:{}", proxy_addr.port()), + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap_err() + .to_string(); + + assert!(!err.is_empty()); + } + #[tokio::test(flavor = "multi_thread")] async fn http_proxy_negotiates_http2_by_default() { let target_addr = example_http_service().await; @@ -1425,6 +1717,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -1452,7 +1745,7 @@ mod tests { NestingTlsConnector::new(Arc::new(outer_client_config), Arc::new(inner_client_config)); let (sender, conn, _attestation) = ProxyClient::setup_connection( - &nesting_tls_connector, + &ProxyTlsConnector::Nested(nesting_tls_connector), &format!("localhost:{}", proxy_addr.port()), ) .await @@ -1480,6 +1773,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -1551,6 +1845,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), @@ -1612,6 +1907,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), @@ -1675,6 +1971,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::mock(), @@ -1749,6 +2046,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::mock(), @@ -1813,6 +2111,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -1861,6 +2160,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), @@ -1907,6 +2207,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -1979,6 +2280,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -2099,6 +2401,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), @@ -2186,6 +2489,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::with_no_attestation(), AttestationVerifier::expect_none(), @@ -2246,6 +2550,7 @@ mod tests { }, }), Some("127.0.0.1:0"), + None, target_addr.to_string(), AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), AttestationVerifier::expect_none(), diff --git a/src/main.rs b/src/main.rs index f4be61c..f02ab0a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,8 +9,10 @@ use tracing::level_filters::LevelFilter; use attested_tls_proxy::{ AttestationGenerator, OuterTlsConfig, OuterTlsMode, ProxyClient, ProxyServer, TlsCertAndKey, - attested_get::attested_get, file_server::attested_file_server, get_inner_tls_cert, - health_check, normalize_pem::normalize_private_key_pem_to_pkcs8, + attested_get::attested_get, + file_server::{AttestedFileServerConfig, attested_file_server}, + get_inner_tls_cert, health_check, + normalize_pem::normalize_private_key_pem_to_pkcs8, }; const GIT_REV: &str = match option_env!("GIT_REV") { @@ -53,16 +55,19 @@ enum CliCommand { /// Socket address to listen on #[arg(short, long, default_value = "0.0.0.0:0", env = "LISTEN_ADDR")] listen_addr: SocketAddr, + /// Connect directly to the server's inner attested TLS listener instead of nested TLS + #[arg(long)] + inner_session_only: bool, /// The hostname:port or ip:port of the proxy server (port defaults to 443) target_addr: String, /// Type of attestation to present (dafaults to 'auto' for automatic detection) /// If other than None, a TLS key and certicate must also be given #[arg(long, env = "CLIENT_ATTESTATION_TYPE")] client_attestation_type: Option, - /// The path to a PEM encoded private key for client authentication + /// The path to a PEM encoded private key for client authentication in nested-TLS mode #[arg(long, env = "TLS_PRIVATE_KEY_PATH")] tls_private_key_path: Option, - /// The path to a PEM encoded certificate chain for client authentication + /// The path to a PEM encoded certificate chain for client authentication in nested-TLS mode #[arg(long, env = "TLS_CERTIFICATE_PATH")] tls_certificate_path: Option, /// Additional CA certificate to verify against (PEM) Defaults to no additional TLS certs. @@ -84,6 +89,9 @@ enum CliCommand { /// Socket address to listen on for the inner-only attested TLS listener #[arg(long)] inner_listen_addr: Option, + /// DNS name to embed into the inner attested certificate when no outer listener is used + #[arg(long)] + inner_certificate_name: Option, /// The hostname:port or ip:port of the target service to forward traffic to target_addr: String, /// Type of attestation to present (dafaults to 'auto' for automatic detection) @@ -129,6 +137,9 @@ enum CliCommand { /// Socket address to listen on for the inner-only attested TLS listener #[arg(long)] inner_listen_addr: Option, + /// DNS name to embed into the inner attested certificate when no outer listener is used + #[arg(long)] + inner_certificate_name: Option, /// Type of attestation to present (dafaults to none) /// This configures the inner attested TLS listener and does not require outer TLS certs. #[arg(long, env = "SERVER_ATTESTATION_TYPE")] @@ -225,6 +236,7 @@ async fn main() -> anyhow::Result<()> { match cli.command { CliCommand::Client { listen_addr, + inner_session_only, target_addr, client_attestation_type, tls_private_key_path, @@ -242,6 +254,13 @@ async fn main() -> anyhow::Result<()> { health_check::server(listen_addr_healthcheck).await?; } + validate_client_args( + inner_session_only, + tls_private_key_path.as_ref(), + tls_certificate_path.as_ref(), + tls_ca_certificate.as_ref(), + )?; + let tls_cert_and_chain = if let Some(private_key) = tls_private_key_path { Some(load_tls_cert_and_key( tls_certificate_path @@ -270,15 +289,26 @@ async fn main() -> anyhow::Result<()> { AttestationGenerator::new_with_detection(client_attestation_type, dev_dummy_dcap) .await?; - let client = ProxyClient::new( - tls_cert_and_chain, - listen_addr, - target_addr, - client_attestation_generator, - attestation_verifier, - remote_tls_cert, - ) - .await?; + let client = if inner_session_only { + ProxyClient::new_inner_only( + tls_cert_and_chain, + listen_addr, + target_addr, + client_attestation_generator, + attestation_verifier, + ) + .await? + } else { + ProxyClient::new( + tls_cert_and_chain, + listen_addr, + target_addr, + client_attestation_generator, + attestation_verifier, + remote_tls_cert, + ) + .await? + }; loop { if let Err(err) = client.accept().await { @@ -289,6 +319,7 @@ async fn main() -> anyhow::Result<()> { CliCommand::Server { outer_listen_addr, inner_listen_addr, + inner_certificate_name, target_addr, tls_private_key_path, tls_certificate_path, @@ -321,6 +352,7 @@ async fn main() -> anyhow::Result<()> { tls: OuterTlsMode::CertAndKey(cert_and_key), }), inner_listen_addr, + inner_certificate_name, target_addr, local_attestation_generator, attestation_verifier, @@ -367,6 +399,7 @@ async fn main() -> anyhow::Result<()> { path_to_serve, outer_listen_addr, inner_listen_addr, + inner_certificate_name, server_attestation_type, tls_private_key_path, tls_certificate_path, @@ -387,15 +420,16 @@ async fn main() -> anyhow::Result<()> { let attestation_generator = AttestationGenerator::new(server_attestation_type, dev_dummy_dcap)?; - attested_file_server( + attested_file_server(AttestedFileServerConfig { path_to_serve, - tls_cert_and_chain, + outer_cert_and_key: tls_cert_and_chain, outer_listen_addr, inner_listen_addr, + inner_certificate_name, attestation_generator, attestation_verifier, - false, - ) + client_auth: false, + }) .await?; } CliCommand::AttestedGet { @@ -475,6 +509,27 @@ fn validate_listener_args( Ok(()) } +fn validate_client_args( + inner_session_only: bool, + tls_private_key_path: Option<&PathBuf>, + tls_certificate_path: Option<&PathBuf>, + tls_ca_certificate: Option<&PathBuf>, +) -> anyhow::Result<()> { + if inner_session_only && tls_ca_certificate.is_some() { + return Err(anyhow!( + "--tls-ca-certificate cannot be used with --inner-session-only" + )); + } + + if inner_session_only && (tls_private_key_path.is_some() || tls_certificate_path.is_some()) { + return Err(anyhow!( + "--tls-private-key-path and --tls-certificate-path are not supported with --inner-session-only" + )); + } + + Ok(()) +} + /// Load TLS details from storage fn load_tls_cert_and_key( cert_chain: PathBuf, @@ -508,3 +563,27 @@ fn certs_to_pem_string(certs: &[CertificateDer<'_>]) -> Result