diff --git a/Cargo.toml b/Cargo.toml index 8313db4..a389c0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ resolver = "2" [workspace.package] version = "0.16.1" -edition = "2021" +edition = "2024" publish = false authors = ["FastEdge Development Team"] diff --git a/crates/http-backend/src/lib.rs b/crates/http-backend/src/lib.rs index 3af701d..f14ed10 100644 --- a/crates/http-backend/src/lib.rs +++ b/crates/http-backend/src/lib.rs @@ -1,6 +1,6 @@ pub mod stats; -use smol_str::SmolStr; +use smol_str::{SmolStr, ToSmolStr}; use std::fmt::Debug; use std::future::Future; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -9,13 +9,13 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; -use anyhow::{anyhow, Error, Result}; -use http::{header, uri::Scheme, HeaderMap, HeaderName, Uri}; +use anyhow::{Error, Result, anyhow}; +use http::{HeaderMap, HeaderName, Uri, header, uri::Scheme}; use http_body_util::{BodyExt, Full}; use hyper::body::Bytes; use hyper::rt::ReadBufCursor; -use hyper_util::client::legacy::connect::{Connect, HttpConnector}; use hyper_util::client::legacy::Client; +use hyper_util::client::legacy::connect::{Connect, HttpConnector}; use hyper_util::rt::TokioExecutor; use pin_project::pin_project; use tokio::net::TcpStream; @@ -31,6 +31,8 @@ use reactor::gcore::fastedge::{ type HeaderNameList = Vec; +pub const SERVER_NAME_HEADER: &str = "server_name"; + #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum BackendStrategy { Direct, @@ -78,8 +80,8 @@ impl Builder { self } - pub fn hostname(&mut self, hostname: SmolStr) -> &mut Self { - self.hostname = Some(hostname); + pub fn hostname(&mut self, hostname: impl ToSmolStr) -> &mut Self { + self.hostname = Some(hostname.to_smolstr()); self } @@ -146,15 +148,11 @@ impl Backend { pub fn propagate_headers(&mut self, headers: HeaderMap) -> Result<()> { self.propagate_headers.clear(); - if self.strategy == BackendStrategy::FastEdge { - let server_name = headers - .get("server_name") - .and_then(|v| v.to_str().ok()) - .ok_or(anyhow!("header Server_name is missing"))?; - self.propagate_headers.insert( - HeaderName::from_static("host"), - be_base_domain(server_name).parse()?, - ); + if self.strategy == BackendStrategy::FastEdge + && let Some(ref hostname) = self.hostname + { + self.propagate_headers + .insert(header::HOST, hostname.parse()?); } let headers = headers.into_iter().filter(|(k, _)| { if let Some(name) = k { @@ -186,18 +184,16 @@ impl Backend { if !headers .iter() .any(|(k, _)| k.eq_ignore_ascii_case(header::HOST.as_str())) - { - if let Ok(uri) = req.uri.parse::() { - if let Some(host) = uri.authority().map(|a| { - if let Some(port) = a.port() { - format!("{}:{}", a.host(), port) - } else { - a.host().to_string() - } - }) { - headers.push((header::HOST.as_str().to_string(), host)) + && let Ok(uri) = req.uri.parse::() + && let Some(host) = uri.authority().map(|a| { + if let Some(port) = a.port() { + format!("{}:{}", a.host(), port) + } else { + a.host().to_string() } - } + }) + { + headers.push((header::HOST.as_str().to_string(), host)) } let builder = http::Request::builder().uri(req.uri); @@ -370,17 +366,6 @@ where } } -fn be_base_domain(server_name: &str) -> String { - let base_domain = match server_name.find('.') { - None => server_name, - Some(i) => { - let (_, domain) = server_name.split_at(i + 1); - domain - } - }; - format!("be.{}", base_domain) -} - // extract canonical host name fn canonical_host_name(headers: &Headers, original_uri: &Uri) -> Result { let host = headers.iter().find_map(|(k, v)| { @@ -600,9 +585,9 @@ mod tests { let connector = builder.build(); let mut backend = Backend::::builder(BackendStrategy::FastEdge) + .hostname("be.server") .build(connector); - let mut headers = HeaderMap::new(); - headers.insert("Server_name", claims::assert_ok!("server".try_into())); + let headers = HeaderMap::new(); claims::assert_ok!(backend.propagate_headers(headers)); let req = Request { method: Method::Get, @@ -632,9 +617,9 @@ mod tests { let connector = builder.build(); let mut backend = Backend::::builder(BackendStrategy::FastEdge) + .hostname("be.server") .build(connector); - let mut headers = HeaderMap::new(); - headers.insert("Server_name", claims::assert_ok!("server".try_into())); + let headers = HeaderMap::new(); claims::assert_ok!(backend.propagate_headers(headers)); let req = Request { method: Method::Get, @@ -664,9 +649,9 @@ mod tests { let connector = builder.build(); let mut backend = Backend::::builder(BackendStrategy::FastEdge) + .hostname("be.server") .build(connector); - let mut headers = HeaderMap::new(); - headers.insert("Server_name", claims::assert_ok!("server".try_into())); + let headers = HeaderMap::new(); claims::assert_ok!(backend.propagate_headers(headers)); let req = Request { method: Method::Get, @@ -725,9 +710,9 @@ mod tests { let connector = builder.build(); let mut backend = Backend::::builder(BackendStrategy::FastEdge) + .hostname("be.server") .build(connector); - let mut headers = HeaderMap::new(); - headers.insert("Server_name", claims::assert_ok!("server".try_into())); + let headers = HeaderMap::new(); claims::assert_ok!(backend.propagate_headers(headers)); let req = Request { method: Method::Get, @@ -765,10 +750,10 @@ mod tests { let connector = builder.build(); let mut backend = Backend::::builder(BackendStrategy::FastEdge) + .hostname("be.server") .propagate_headers_names(vec!["Propagate-Header".parse().unwrap()]) .build(connector); let mut headers = HeaderMap::new(); - headers.insert("Server_name", claims::assert_ok!("server".try_into())); headers.insert( "No-Propagate-Header", claims::assert_ok!("VALUE".try_into()), @@ -803,11 +788,11 @@ mod tests { let connector = builder.build(); let mut backend = Backend::::builder(BackendStrategy::FastEdge) + .hostname("be.server") .propagate_headers_names(vec!["Propagate-Header".parse().unwrap()]) .uri(assert_ok!("http://be.server/backend_path/".parse())) .build(connector); - let mut headers = HeaderMap::new(); - headers.insert("Server_name", claims::assert_ok!("server".try_into())); + let headers = HeaderMap::new(); claims::assert_ok!(backend.propagate_headers(headers)); let req = Request { @@ -838,11 +823,11 @@ mod tests { let connector = builder.build(); let mut backend = Backend::::builder(BackendStrategy::FastEdge) + .hostname("be.server") .propagate_headers_names(vec!["Propagate-Header".parse().unwrap()]) .max_sub_requests(2) .build(connector); - let mut headers = HeaderMap::new(); - headers.insert("Server_name", claims::assert_ok!("server".try_into())); + let headers = HeaderMap::new(); claims::assert_ok!(backend.propagate_headers(headers)); let req = Request { @@ -987,10 +972,12 @@ mod tests { }; let result = backend.make_request(req); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("private host not allowed")); + assert!( + result + .unwrap_err() + .to_string() + .contains("private host not allowed") + ); // Test private network let req = Request { @@ -1001,10 +988,12 @@ mod tests { }; let result = backend.make_request(req); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("private host not allowed")); + assert!( + result + .unwrap_err() + .to_string() + .contains("private host not allowed") + ); // Test another private network let req = Request { @@ -1015,10 +1004,12 @@ mod tests { }; let result = backend.make_request(req); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("private host not allowed")); + assert!( + result + .unwrap_err() + .to_string() + .contains("private host not allowed") + ); // Test link-local let req = Request { @@ -1029,10 +1020,12 @@ mod tests { }; let result = backend.make_request(req); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("private host not allowed")); + assert!( + result + .unwrap_err() + .to_string() + .contains("private host not allowed") + ); } #[test] @@ -1050,10 +1043,12 @@ mod tests { }; let result = backend.make_request(req); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("private host not allowed")); + assert!( + result + .unwrap_err() + .to_string() + .contains("private host not allowed") + ); // Test with Host header containing private IP with port let req = Request { @@ -1064,10 +1059,12 @@ mod tests { }; let result = backend.make_request(req); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("private host not allowed")); + assert!( + result + .unwrap_err() + .to_string() + .contains("private host not allowed") + ); } #[test] @@ -1085,10 +1082,12 @@ mod tests { }; let result = backend.make_request(req); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("private host not allowed")); + assert!( + result + .unwrap_err() + .to_string() + .contains("private host not allowed") + ); // Test unique local let req = Request { @@ -1099,10 +1098,12 @@ mod tests { }; let result = backend.make_request(req); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("private host not allowed")); + assert!( + result + .unwrap_err() + .to_string() + .contains("private host not allowed") + ); // Test link-local let req = Request { @@ -1113,10 +1114,12 @@ mod tests { }; let result = backend.make_request(req); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("private host not allowed")); + assert!( + result + .unwrap_err() + .to_string() + .contains("private host not allowed") + ); } #[test] diff --git a/crates/http-service/src/executor/http.rs b/crates/http-service/src/executor/http.rs index 7a7480a..cf20a9e 100644 --- a/crates/http-service/src/executor/http.rs +++ b/crates/http-service/src/executor/http.rs @@ -1,7 +1,7 @@ use crate::executor; use crate::executor::HttpExecutor; use crate::state::HttpState; -use anyhow::{anyhow, bail, Context}; +use anyhow::{Context, anyhow, bail}; use async_trait::async_trait; use http::{Method, Request, Response, StatusCode}; use http_backend::Backend; @@ -9,7 +9,7 @@ use http_body_util::{BodyExt, Full}; use hyper::body::Body; use reactor::gcore::fastedge; use runtime::util::stats::{StatsTimer, StatsVisitor}; -use runtime::{store::StoreBuilder, InstancePre}; +use runtime::{InstancePre, store::StoreBuilder}; use std::sync::Arc; use std::time::Duration; use wasmtime_wasi_http::body::HyperOutgoingBody; @@ -176,14 +176,14 @@ mod tests { use super::*; use crate::executor::http::HttpExecutorImpl; use crate::{ - ContextHeaders, ExecutorFactory, HttpService, FASTEDGE_EXECUTION_TIMEOUT, - FASTEDGE_OUT_OF_MEMORY, INTERNAL_STATUS_OUT_OF_MEMORY, INTERNAL_STATUS_TIMEOUT_ELAPSED, + ContextHeaders, ExecutorFactory, FASTEDGE_EXECUTION_TIMEOUT, FASTEDGE_OUT_OF_MEMORY, + HttpService, INTERNAL_STATUS_OUT_OF_MEMORY, INTERNAL_STATUS_TIMEOUT_ELAPSED, INTERNAL_STATUS_TIMEOUT_INTERRUPT, X_CDN_INTERNAL_STATUS, }; use bytes::Bytes; use claims::*; use http_backend::stats::ExtRequestStats; - use http_backend::{Backend, BackendStrategy, FastEdgeConnector}; + use http_backend::{Backend, BackendStrategy, FastEdgeConnector, SERVER_NAME_HEADER}; use http_body_util::Empty; use key_value_store::ReadStats; use runtime::app::{KvStoreOption, SecretOption, Status}; @@ -191,8 +191,8 @@ mod tests { use runtime::service::ServiceBuilder; use runtime::util::stats::CdnPhase; use runtime::{ - componentize_if_necessary, App, ContextT, PreCompiledLoader, Router, WasiVersion, - WasmConfig, WasmEngine, + App, ContextT, PreCompiledLoader, Router, WasiVersion, WasmConfig, WasmEngine, + componentize_if_necessary, }; use secret::SecretStore; use smol_str::{SmolStr, ToSmolStr}; @@ -309,8 +309,12 @@ mod tests { self.app.clone() } - async fn lookup_by_id(&self, _id: u64) -> Option<(SmolStr, App)> { - todo!() + async fn lookup_by_id(&self, id: u64) -> Option<(SmolStr, App)> { + // Mirror the production behaviour: an `Id` is resolved into a (name, App) pair. + // We synthesise a name from the id so tests can assert on it. + self.app + .clone() + .map(|app| (format!("app-{id}").to_smolstr(), app)) } } @@ -401,15 +405,17 @@ mod tests { #[tokio::test] #[tracing_test::traced_test] async fn test_success() { - let req = assert_ok!(Request::builder() - .method("GET") - .uri("http://www.rust-lang.org/") - .header("server_name", "success.test.com") - .body( - Empty::::new() - .map_err(|never| match never {}) - .boxed() - )); + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header(SERVER_NAME_HEADER, "success.test.com") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); let context = TestContext { geo: load_geo_info(), @@ -436,15 +442,17 @@ mod tests { #[tokio::test] #[tracing_test::traced_test] async fn test_timeout() { - let req = assert_ok!(Request::builder() - .method("GET") - .uri("http://www.rust-lang.org/") - .header("server_name", "timeout.test.com") - .body( - Empty::::new() - .map_err(|never| match never {}) - .boxed() - )); + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header(SERVER_NAME_HEADER, "timeout.test.com") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); let app = Some(App { binary_id: 1, @@ -498,15 +506,17 @@ mod tests { #[tokio::test] #[tracing_test::traced_test] async fn test_insufficient_memory() { - let req = assert_ok!(Request::builder() - .method("GET") - .uri("http://www.rust-lang.org/?size=200000") - .header("server_name", "insufficient_memory.test.com") - .body( - Empty::::new() - .map_err(|never| match never {}) - .boxed() - )); + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/?size=200000") + .header(SERVER_NAME_HEADER, "insufficient_memory.test.com") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); let app = Some(App { binary_id: 100, @@ -556,15 +566,17 @@ mod tests { #[tokio::test] #[tracing_test::traced_test] async fn draft_app() { - let req = assert_ok!(Request::builder() - .method("GET") - .uri("http://www.rust-lang.org/") - .header("server_name", "draft.test.com") - .body( - Empty::::new() - .map_err(|never| match never {}) - .boxed() - )); + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header(SERVER_NAME_HEADER, "draft.test.com") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); let context = TestContext { geo: load_geo_info(), @@ -582,15 +594,17 @@ mod tests { #[tokio::test] #[tracing_test::traced_test] async fn disabled_app() { - let req = assert_ok!(Request::builder() - .method("GET") - .uri("http://www.rust-lang.org/") - .header("server_name", "draft.test.com") - .body( - Empty::::new() - .map_err(|never| match never {}) - .boxed() - )); + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header(SERVER_NAME_HEADER, "draft.test.com") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); let context = TestContext { geo: load_geo_info(), @@ -608,15 +622,17 @@ mod tests { #[tokio::test] #[tracing_test::traced_test] async fn rate_limit_app() { - let req = assert_ok!(Request::builder() - .method("GET") - .uri("http://www.rust-lang.org/") - .header("server_name", "draft.test.com") - .body( - Empty::::new() - .map_err(|never| match never {}) - .boxed() - )); + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header(SERVER_NAME_HEADER, "draft.test.com") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); let context = TestContext { geo: load_geo_info(), @@ -634,15 +650,17 @@ mod tests { #[tokio::test] #[tracing_test::traced_test] async fn suspended_app() { - let req = assert_ok!(Request::builder() - .method("GET") - .uri("http://www.rust-lang.org/") - .header("server_name", "draft.test.com") - .body( - Empty::::new() - .map_err(|never| match never {}) - .boxed() - )); + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header(SERVER_NAME_HEADER, "draft.test.com") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); let context = TestContext { geo: load_geo_info(), @@ -656,4 +674,229 @@ mod tests { assert_eq!(StatusCode::NOT_ACCEPTABLE, res.status()); assert_eq!(0, res.headers().len()); } + + // ── handle_request: fastedge_app_id header (Id variant) ────────────── + + /// `fastedge_app_id` is honoured: the request is routed through `lookup_by_id` + /// and reaches the executor exactly like the `server_name`-based path does. + #[tokio::test] + #[tracing_test::traced_test] + async fn test_success_with_fastedge_app_id() { + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header("fastedge_app_id", "12345") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + + let context = TestContext { + geo: load_geo_info(), + app: default_test_app(Status::Enabled), + engine: make_engine(), + }; + + let http_service: HttpService = + assert_ok!(ServiceBuilder::new(context).build()); + + let res = assert_ok!(http_service.handle_request("8".to_smolstr(), req).await); + assert_eq!(StatusCode::OK, res.status()); + let headers = res.headers(); + assert_eq!(4, headers.len()); + assert_eq!( + "*", + assert_some!(headers.get("access-control-allow-origin")) + ); + assert_eq!("no-store", assert_some!(headers.get("cache-control"))); + assert_eq!("01", assert_some!(headers.get("RES_HEADER_01"))); + assert_eq!("02", assert_some!(headers.get("RES_HEADER_02"))); + } + + /// `fastedge_app_id` wins over `server_name` when both are present. + #[tokio::test] + #[tracing_test::traced_test] + async fn test_fastedge_app_id_takes_priority_over_server_name() { + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header("fastedge_app_id", "777") + .header(SERVER_NAME_HEADER, "other.test.com") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + + let context = TestContext { + geo: load_geo_info(), + app: default_test_app(Status::Enabled), + engine: make_engine(), + }; + + let http_service: HttpService = + assert_ok!(ServiceBuilder::new(context).build()); + + let res = assert_ok!(http_service.handle_request("9".to_smolstr(), req).await); + // Reaching OK proves we resolved via lookup_by_id (otherwise our mock for + // lookup_by_name path would have produced the same status, but we also + // verify in `test_fastedge_app_id_invalid_returns_not_found` below that + // a malformed id short-circuits before touching either lookup). + assert_eq!(StatusCode::OK, res.status()); + } + + /// A non-numeric `fastedge_app_id` header makes `app_name_from_request` fail, + /// which `handle_request` maps to a 404. + #[tokio::test] + #[tracing_test::traced_test] + async fn test_fastedge_app_id_invalid_returns_not_found() { + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header("fastedge_app_id", "not-a-number") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + + let context = TestContext { + geo: load_geo_info(), + app: default_test_app(Status::Enabled), + engine: make_engine(), + }; + + let http_service: HttpService = + assert_ok!(ServiceBuilder::new(context).build()); + + let res = assert_ok!(http_service.handle_request("10".to_smolstr(), req).await); + assert_eq!(StatusCode::NOT_FOUND, res.status()); + assert_eq!(0, res.headers().len()); + } + + /// `fastedge_app_id` resolves to an unknown application → 404. + #[tokio::test] + #[tracing_test::traced_test] + async fn test_fastedge_app_id_unknown_app_returns_not_found() { + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header("fastedge_app_id", "42") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + + // No app registered in the mock router → lookup_by_id returns None + let context = TestContext { + geo: load_geo_info(), + app: None, + engine: make_engine(), + }; + + let http_service: HttpService = + assert_ok!(ServiceBuilder::new(context).build()); + + let res = assert_ok!(http_service.handle_request("11".to_smolstr(), req).await); + assert_eq!(StatusCode::NOT_FOUND, res.status()); + assert_eq!(0, res.headers().len()); + } + + /// Disabled-status apps reached via `fastedge_app_id` also return 404. + #[tokio::test] + #[tracing_test::traced_test] + async fn test_fastedge_app_id_disabled_returns_not_found() { + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header("fastedge_app_id", "12345") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + + let context = TestContext { + geo: load_geo_info(), + app: default_test_app(Status::Disabled), + engine: make_engine(), + }; + + let http_service: HttpService = + assert_ok!(ServiceBuilder::new(context).build()); + + let res = assert_ok!(http_service.handle_request("12".to_smolstr(), req).await); + assert_eq!(StatusCode::NOT_FOUND, res.status()); + assert_eq!(0, res.headers().len()); + } + + // ── handle_request: server_name unknown app ─────────────────────────── + + /// A `server_name`-based request whose app is not registered → 404. + #[tokio::test] + #[tracing_test::traced_test] + async fn test_server_name_unknown_app_returns_not_found() { + let req = assert_ok!( + Request::builder() + .method("GET") + .uri("http://www.rust-lang.org/") + .header(SERVER_NAME_HEADER, "ghost.test.com") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + + let context = TestContext { + geo: load_geo_info(), + app: None, + engine: make_engine(), + }; + + let http_service: HttpService = + assert_ok!(ServiceBuilder::new(context).build()); + + let res = assert_ok!(http_service.handle_request("13".to_smolstr(), req).await); + assert_eq!(StatusCode::NOT_FOUND, res.status()); + assert_eq!(0, res.headers().len()); + } + + /// No `server_name` and no path segment → `app_name_from_request` errors → 404. + #[tokio::test] + #[tracing_test::traced_test] + async fn test_no_app_name_returns_not_found() { + let req = assert_ok!( + Request::builder().method("GET").uri("/").body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + + let context = TestContext { + geo: load_geo_info(), + app: default_test_app(Status::Enabled), + engine: make_engine(), + }; + + let http_service: HttpService = + assert_ok!(ServiceBuilder::new(context).build()); + + let res = assert_ok!(http_service.handle_request("14".to_smolstr(), req).await); + assert_eq!(StatusCode::NOT_FOUND, res.status()); + assert_eq!(0, res.headers().len()); + } } diff --git a/crates/http-service/src/executor/wasi_http.rs b/crates/http-service/src/executor/wasi_http.rs index 1141712..0c57624 100644 --- a/crates/http-service/src/executor/wasi_http.rs +++ b/crates/http-service/src/executor/wasi_http.rs @@ -4,17 +4,18 @@ use std::time::Duration; use crate::executor; use crate::executor::HttpExecutor; use crate::state::HttpState; -use ::http::{header, HeaderMap, Request, Response, Uri}; -use anyhow::{anyhow, bail, Context}; +use ::http::{HeaderMap, Request, Response, Uri, header}; +use anyhow::{Context, anyhow, bail}; use async_trait::async_trait; use http_backend::Backend; use http_body_util::{BodyExt, Full}; use hyper::body::Body; use runtime::util::stats::{StatsTimer, StatsVisitor}; -use runtime::{store::StoreBuilder, InstancePre}; -use wasmtime_wasi_http::bindings::http::types::Scheme; +use runtime::{InstancePre, store::StoreBuilder}; +use smol_str::SmolStr; use wasmtime_wasi_http::bindings::ProxyPre; -use wasmtime_wasi_http::{body::HyperOutgoingBody, WasiHttpView}; +use wasmtime_wasi_http::bindings::http::types::Scheme; +use wasmtime_wasi_http::{WasiHttpView, body::HyperOutgoingBody}; /// Execute context used by ['HttpService'] #[derive(Clone)] @@ -44,18 +45,22 @@ where let (sender, receiver) = tokio::sync::oneshot::channel(); let (mut parts, body) = req.into_parts(); - let server_name = parts - .headers - .get("server_name") - .and_then(|v| v.to_str().ok()) - .ok_or(anyhow!("header Server_name is missing"))?; + const LOCALHOST: SmolStr = SmolStr::new_inline("localhost"); + let backend_hostname = self.backend.hostname().unwrap_or(LOCALHOST); + let hostname = match backend_hostname.find('.') { + None => backend_hostname.as_str(), + Some(i) => { + let (_, domain) = backend_hostname.split_at(i + 1); + domain + } + }; // fix relative uri to absolute if parts.uri.scheme().is_none() { let mut uparts = parts.uri.clone().into_parts(); uparts.scheme = Some(::http::uri::Scheme::HTTP); if uparts.authority.is_none() { - uparts.authority = server_name.parse().ok() + uparts.authority = hostname.parse().ok() } parts.uri = Uri::from_parts(uparts)?; } @@ -90,7 +95,7 @@ where }) .collect(); - propagate_headers.insert(header::HOST, be_base_domain(server_name).parse()?); + propagate_headers.insert(header::HOST, backend_hostname.parse()?); let backend_uri = http_backend.uri(); let state = HttpState { @@ -173,14 +178,3 @@ where } } } - -fn be_base_domain(server_name: &str) -> String { - let base_domain = match server_name.find('.') { - None => server_name, - Some(i) => { - let (_, domain) = server_name.split_at(i + 1); - domain - } - }; - format!("be.{}", base_domain) -} diff --git a/crates/http-service/src/lib.rs b/crates/http-service/src/lib.rs index 1c8d571..b1d5ee3 100644 --- a/crates/http-service/src/lib.rs +++ b/crates/http-service/src/lib.rs @@ -1,3 +1,4 @@ +use std::fmt::Display; use std::marker::PhantomData; use std::net::SocketAddr; use std::sync::Arc; @@ -7,12 +8,13 @@ use wasmtime_wasi_nn::wit::WasiNnView; pub use crate::executor::ExecutorFactory; use crate::executor::HttpExecutor; -use anyhow::{bail, Error, Result}; +use anyhow::{Context, Error, Result, bail}; use bytes::Bytes; use http::{ - header::{ACCESS_CONTROL_ALLOW_ORIGIN, CACHE_CONTROL}, HeaderMap, HeaderName, HeaderValue, StatusCode, + header::{ACCESS_CONTROL_ALLOW_ORIGIN, CACHE_CONTROL}, }; +use http_backend::SERVER_NAME_HEADER; use http_body_util::{BodyExt, Empty, Full}; use hyper::{body::Body, server::conn::http1, service::service_fn}; use hyper_util::{client::legacy::connect::Connect, rt::TokioIo}; @@ -20,11 +22,12 @@ use hyper_util::{client::legacy::connect::Connect, rt::TokioIo}; use runtime::util::metrics; use runtime::util::stats::StatsVisitor; use runtime::{ - app::Status, service::Service, App, AppResult, ContextT, Router, WasmEngine, WasmEngineBuilder, + App, AppResult, ContextT, Router, WasmEngine, WasmEngineBuilder, app::Status, service::Service, }; use smol_str::{SmolStr, ToSmolStr}; use state::HttpState; use tokio::{net::TcpListener, time::error::Elapsed}; +use tracing::Instrument; pub use wasmtime_wasi_http::body::HyperOutgoingBody; pub mod executor; @@ -287,26 +290,32 @@ where } Ok(app_name) => app_name, }; - let span = tracing::info_span!("handle", app = app_name.as_str()); + + let span = tracing::info_span!("handle", app = %app_name); let _enter = span.enter(); // lookup for application config and binary_id - tracing::debug!( - "Processing request URL: {}", - request.uri() - ); - let cfg = match self.context.lookup_by_name(&app_name).await { + tracing::debug!("Processing request URL: {}", request.uri()); + let lookup = match app_name { + AppName::Id(id) => self.context.lookup_by_id(id).instrument(span.clone()).await, + AppName::Name(name) => self + .context + .lookup_by_name(&name) + .instrument(span.clone()) + .await + .map(|cfg| (name, cfg)), + }; + + let (app_name, cfg) = match lookup { None => { #[cfg(feature = "metrics")] metrics::metrics(AppResult::UNKNOWN, HTTP_LABEL, None, None); - tracing::info!( - "Request for unknown application '{}' on URL: {}", - app_name, - request.uri() - ); + tracing::info!("Request for unknown application on URL: {}", request.uri()); return not_found(); } - Some(cfg) if cfg.status == Status::Draft || cfg.status == Status::Disabled => { + Some((app_name, cfg)) + if cfg.status == Status::Draft || cfg.status == Status::Disabled => + { tracing::info!( "Request for disabled application '{}' on URL: {}", app_name, @@ -314,7 +323,7 @@ where ); return not_found(); } - Some(cfg) if cfg.status == Status::RateLimited => { + Some((app_name, cfg)) if cfg.status == Status::RateLimited => { tracing::info!( "Request for rate limited application '{}' on URL: {}", app_name, @@ -322,7 +331,7 @@ where ); return too_many_requests(); } - Some(app_cfg) if app_cfg.status == Status::Suspended => { + Some((app_name, cfg)) if cfg.status == Status::Suspended => { tracing::info!( "Request for suspended application '{}' on URL: {}", app_name, @@ -331,7 +340,7 @@ where return not_acceptable(); } - Some(cfg) => cfg, + Some((app_name, cfg)) => (app_name, cfg), }; // get cached execute context for this application @@ -352,7 +361,11 @@ where let stats = self.context.new_stats_row(&request_id, &app_name, &cfg); - let response = match executor.execute(request, stats.clone()).await { + let response = match executor + .execute(request, stats.clone()) + .instrument(span.clone()) + .await + { Ok(mut response) => { #[cfg(feature = "metrics")] metrics::metrics( @@ -522,11 +535,51 @@ fn not_acceptable() -> Result> { .body(Empty::new().map_err(|never| match never {}).boxed())?) } -/// borrows the request and returns the apps name -/// app name can be either as sub-domain in a format '.' (from `Server_name` header) -/// or '/' (from URL) -fn app_name_from_request(req: &hyper::Request) -> Result { - match req.headers().get("server_name") { +#[derive(Debug, Clone)] +pub(crate) enum AppName { + Name(SmolStr), + Id(u64), +} + +impl From for AppName +where + SmolStr: From, +{ + fn from(s: T) -> Self { + AppName::Name(SmolStr::from(s)) + } +} + +impl Display for AppName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AppName::Name(name) => write!(f, "{}", name), + AppName::Id(id) => write!(f, "{}", id), + } + } +} + +const FASTEDGE_APP_ID_HEADER: &str = "fastedge_app_id"; + +/// Extracts the application identifier from an incoming HTTP request. +/// +/// Resolution order (first match wins): +/// 1. `fastedge_app_id` header — parsed as a `u64` → [`AppName::Id`] +/// 2. `server_name` header — the leftmost label of the hostname is used as the app name +/// (e.g. `app.example.com` → `"app"`), unless it is `"www"`. +/// 3. URL path — the first path segment is used as the app name +/// (e.g. `/my-app/route` → `"my_app"`; hyphens are normalised to underscores). +/// +/// Returns an error if none of the above yields a non-empty identifier. +fn app_name_from_request(req: &hyper::Request) -> Result { + if let Some(app_id) = req.headers().get(FASTEDGE_APP_ID_HEADER) { + let id = app_id.to_str().context("app_id header is not a string")?; + return Ok(AppName::Id( + id.parse::().context("app_id header is not a number")?, + )); + } + + match req.headers().get(SERVER_NAME_HEADER) { None => {} Some(h) => { let full_hostname = h.to_str().unwrap(); @@ -535,7 +588,7 @@ fn app_name_from_request(req: &hyper::Request) -> Result { Some(i) => { let (prefix, _) = full_hostname.split_at(i); if prefix != "www" { - return Ok(SmolStr::from(prefix)); + return Ok(AppName::from(prefix)); } } } @@ -548,13 +601,13 @@ fn app_name_from_request(req: &hyper::Request) -> Result { } match path.find('/') { - None => Ok(SmolStr::from(path)), + None => Ok(AppName::from(path)), Some(i) => { let (prefix, _) = path.split_at(i); if prefix.contains('-') { - Ok(SmolStr::from(prefix.replace('-', "_"))) + Ok(AppName::from(prefix.replace('-', "_"))) } else { - Ok(SmolStr::from(prefix)) + Ok(AppName::from(prefix)) } } } @@ -616,24 +669,163 @@ pub(crate) mod signal { mod tests { use test_case::test_case; + use crate::AppName; use crate::app_name_from_request; use bytes::Bytes; - use claims::assert_ok; + use claims::{assert_err, assert_ok}; + use http_backend::SERVER_NAME_HEADER; use http_body_util::{BodyExt, Empty}; - #[test_case("app.server.com", "server.com", "app"; "get app name from server_name header")] - fn test_app_name_from_request(server_name: &str, uri: &str, expected: &str) { - let req = assert_ok!(http::Request::builder() - .method("GET") - .uri(uri) - .header("server_name", server_name) - .body( + fn empty_body_request() -> http::request::Builder { + http::Request::builder().method("GET") + } + + // ── Name variant: server_name header ────────────────────────────────── + + #[test_case("app.server.com", "/", "app"; "server_name: normal subdomain")] + #[test_case("foo.example.org", "/ignored", "foo"; "server_name: path is ignored")] + fn test_app_name_from_server_name(server_name: &str, uri: &str, expected: &str) { + let req = assert_ok!( + empty_body_request() + .uri(uri) + .header(SERVER_NAME_HEADER, server_name) + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + let app_name = assert_ok!(app_name_from_request(&req)); + assert!(matches!(&app_name, AppName::Name(n) if n.as_str() == expected)); + } + + #[test] + fn test_app_name_server_name_www_falls_through_to_path() { + // "www" subdomain must be ignored and resolution must fall through to URL path + let req = assert_ok!( + empty_body_request() + .uri("/myapp/route") + .header(SERVER_NAME_HEADER, "www.example.com") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + let app_name = assert_ok!(app_name_from_request(&req)); + assert!(matches!(&app_name, AppName::Name(n) if n.as_str() == "myapp")); + } + + #[test] + fn test_app_name_server_name_no_dot_falls_through_to_path() { + // hostname without a dot must fall through to URL path + let req = assert_ok!( + empty_body_request() + .uri("/myapp") + .header(SERVER_NAME_HEADER, "localhost") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + let app_name = assert_ok!(app_name_from_request(&req)); + assert!(matches!(&app_name, AppName::Name(n) if n.as_str() == "myapp")); + } + + // ── Name variant: URL path ──────────────────────────────────────────── + + #[test_case("/myapp", "myapp"; "path only, no subpath")] + #[test_case("/myapp/route", "myapp"; "path with subpath")] + #[test_case("/my-app/route", "my_app"; "hyphens normalised to underscores")] + fn test_app_name_from_path(uri: &str, expected: &str) { + let req = assert_ok!( + empty_body_request().uri(uri).body( Empty::::new() .map_err(|never| match never {}) .boxed() - )); + ) + ); + let app_name = assert_ok!(app_name_from_request(&req)); + assert!(matches!(&app_name, AppName::Name(n) if n.as_str() == expected)); + } + + // ── Id variant: fastedge_app_id header ─────────────────────────────── + #[test] + fn test_app_name_from_app_id_header() { + let req = assert_ok!( + empty_body_request() + .uri("/") + .header("fastedge_app_id", "42") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); let app_name = assert_ok!(app_name_from_request(&req)); - assert_eq!(expected, app_name); + assert!(matches!(app_name, AppName::Id(42))); + } + + #[test] + fn test_app_name_app_id_takes_priority_over_server_name() { + // fastedge_app_id must win over server_name + let req = assert_ok!( + empty_body_request() + .uri("/") + .header("fastedge_app_id", "99") + .header(SERVER_NAME_HEADER, "other.example.com") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + let app_name = assert_ok!(app_name_from_request(&req)); + assert!(matches!(app_name, AppName::Id(99))); + } + + #[test] + fn test_app_name_app_id_not_a_number_returns_error() { + let req = assert_ok!( + empty_body_request() + .uri("/") + .header("fastedge_app_id", "not-a-number") + .body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + assert_err!(app_name_from_request(&req)); + } + + // ── Error cases ─────────────────────────────────────────────────────── + + #[test] + fn test_app_name_empty_path_returns_error() { + let req = assert_ok!( + empty_body_request().uri("/").body( + Empty::::new() + .map_err(|never| match never {}) + .boxed() + ) + ); + assert_err!(app_name_from_request(&req)); + } + + // ── Display impl ───────────────────────────────────────────────────── + + #[test] + fn test_app_name_display_name() { + let name = AppName::Name("myapp".into()); + assert_eq!("myapp", name.to_string()); + } + + #[test] + fn test_app_name_display_id() { + let id = AppName::Id(1234); + assert_eq!("1234", id.to_string()); } } diff --git a/crates/http-service/src/state.rs b/crates/http-service/src/state.rs index e780e00..c97b956 100644 --- a/crates/http-service/src/state.rs +++ b/crates/http-service/src/state.rs @@ -1,13 +1,14 @@ use anyhow::Error; use http::request::Parts; use http::uri::Scheme; -use http::{header, HeaderMap, HeaderName, Uri}; -use http_backend::is_public_host; +use http::{HeaderMap, HeaderName, Uri, header}; use http_backend::Backend; +use http_backend::is_public_host; +use runtime::BackendRequest; use runtime::store::HasStats; use runtime::util::stats::StatsVisitor; -use runtime::BackendRequest; use std::sync::Arc; +use tracing::instrument; pub struct HttpState { pub(super) http_backend: Backend, @@ -20,6 +21,7 @@ pub struct HttpState { const FASTEDGE_HEADER_HOSTNAME: &[u8] = b"Fastedge_Header_Hostname"; impl BackendRequest for HttpState { + #[instrument(skip(self, head), level = "debug", ret)] fn backend_request(&mut self, mut head: Parts) -> anyhow::Result { match self.http_backend.strategy { http_backend::BackendStrategy::Direct => { diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index eb49147..66444b1 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -34,7 +34,7 @@ use crate::util::stats::StatsVisitor; use anyhow::{anyhow, bail}; pub use app::{App, SecretValue, SecretValues}; use http::request::Parts; -use http::{header, HeaderName, Request}; +use http::{HeaderName, Request, header}; use secret::SecretStore; use smol_str::SmolStr; use std::borrow::Cow; @@ -42,7 +42,7 @@ use wasmtime_environ::wasmparser::{Encoding, Parser, Payload}; use wasmtime_wasi_http::body::HyperOutgoingBody; use wasmtime_wasi_http::{ bindings::http::types::ErrorCode, - types::{default_send_request_handler, HostFutureIncomingResponse, OutgoingRequestConfig}, + types::{HostFutureIncomingResponse, OutgoingRequestConfig, default_send_request_handler}, }; use wasmtime_wasi_nn::wit::WasiNnCtx; diff --git a/crates/runtime/src/registry.rs b/crates/runtime/src/registry.rs index 8f1ca49..cb502fa 100644 --- a/crates/runtime/src/registry.rs +++ b/crates/runtime/src/registry.rs @@ -3,14 +3,15 @@ use std::ops::Deref; use std::path::Path; use std::time::Duration; -use anyhow::{anyhow, bail, Context, Result}; +use anyhow::{Context, Result, anyhow, bail}; use moka::sync::Cache; use wasmtime_wasi_nn::backend::candle::CandleBackend; use wasmtime_wasi_nn::wit::types::GraphEncoding; use wasmtime_wasi_nn::{ - backend::{openvino::OpenvinoBackend, BackendFromDir}, + GraphRegistry, Registry, + backend::{BackendFromDir, openvino::OpenvinoBackend}, wit::types::ExecutionTarget, - GraphRegistry, Registry, {Backend, Graph}, + {Backend, Graph}, }; #[derive(Clone)] diff --git a/crates/runtime/src/store.rs b/crates/runtime/src/store.rs index 6941429..6f4acaf 100644 --- a/crates/runtime/src/store.rs +++ b/crates/runtime/src/store.rs @@ -2,7 +2,7 @@ use crate::limiter::ProxyLimiter; use crate::logger::Logger; use crate::registry::CachedGraphRegistry; use crate::util::stats::StatsVisitor; -use crate::{Data, Wasi, WasiVersion, DEFAULT_EPOCH_TICK_INTERVAL}; +use crate::{DEFAULT_EPOCH_TICK_INTERVAL, Data, Wasi, WasiVersion}; use anyhow::Result; use secret::SecretStore; use std::sync::Arc; diff --git a/crates/runtime/src/stub.rs b/crates/runtime/src/stub.rs index 0402a24..0952331 100644 --- a/crates/runtime/src/stub.rs +++ b/crates/runtime/src/stub.rs @@ -1,5 +1,5 @@ -use crate::service::Service; use crate::WasmEngine; +use crate::service::Service; use tokio_util::sync::CancellationToken; pub struct StubService; diff --git a/crates/runtime/src/util/metrics.rs b/crates/runtime/src/util/metrics.rs index ef7dc7a..79c2956 100644 --- a/crates/runtime/src/util/metrics.rs +++ b/crates/runtime/src/util/metrics.rs @@ -1,6 +1,6 @@ use lazy_static::lazy_static; use prometheus::{ - self, register_histogram_vec, register_int_counter_vec, HistogramVec, IntCounterVec, + self, HistogramVec, IntCounterVec, register_histogram_vec, register_int_counter_vec, }; use crate::AppResult; diff --git a/crates/runtime/src/util/stats.rs b/crates/runtime/src/util/stats.rs index afd1e32..b4a7cfa 100644 --- a/crates/runtime/src/util/stats.rs +++ b/crates/runtime/src/util/stats.rs @@ -83,8 +83,8 @@ impl Drop for StatsTimer { #[cfg(test)] mod tests { use super::*; - use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU16, AtomicU64, Ordering}; use std::sync::Mutex; + use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU16, AtomicU64, Ordering}; use std::thread; // Mock implementation of StatsVisitor for testing diff --git a/src/context.rs b/src/context.rs index 7860bb0..2a21cf7 100644 --- a/src/context.rs +++ b/src/context.rs @@ -2,8 +2,8 @@ use crate::cache::MemoryCacheBackend; use crate::executor::RunExecutor; use crate::key_value::CliStoreManager; use crate::secret::SecretImpl; -use http_backend::stats::ExtRequestStats; use http_backend::Backend; +use http_backend::stats::ExtRequestStats; use http_service::executor::{HttpExecutorImpl, WasiHttpExecutorImpl}; use http_service::state::HttpState; use http_service::{ContextHeaders, ExecutorFactory}; @@ -14,14 +14,14 @@ use runtime::app::{KvStoreOption, SecretOption}; use runtime::logger::{Console, Logger}; use runtime::util::stats::{CdnPhase, StatsVisitor}; use runtime::{ - componentize_if_necessary, App, ContextT, ExecutorCache, PreCompiledLoader, Router, - WasiVersion, WasmEngine, + App, ContextT, ExecutorCache, PreCompiledLoader, Router, WasiVersion, WasmEngine, + componentize_if_necessary, }; use secret::SecretStore; use smol_str::SmolStr; use std::collections::HashMap; -use std::sync::atomic::AtomicU64; use std::sync::Arc; +use std::sync::atomic::AtomicU64; use std::time::Duration; use utils::{Dictionary, UserDiagStats}; use wasmtime::component::Component; diff --git a/src/executor.rs b/src/executor.rs index e5dc907..ae05a1c 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -1,8 +1,8 @@ use async_trait::async_trait; use http::{Request, Response}; use http_body_util::BodyExt; -use http_service::executor::{HttpExecutor, HttpExecutorImpl, WasiHttpExecutorImpl}; use http_service::HyperOutgoingBody; +use http_service::executor::{HttpExecutor, HttpExecutorImpl, WasiHttpExecutorImpl}; use hyper::body::Body; use hyper_tls::HttpsConnector; use hyper_util::client::legacy::connect::HttpConnector; diff --git a/src/main.rs b/src/main.rs index 21012b3..d3cf297 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,7 +11,7 @@ use bytesize::MB; use clap::{Args, Parser, Subcommand}; use context::Context; use dotenv::{DotEnvInjector, EnvArgType}; -use http_backend::{Backend, BackendStrategy}; +use http_backend::{Backend, BackendStrategy, SERVER_NAME_HEADER}; use http_service::{HttpConfig, HttpService}; use hyper_tls::HttpsConnector; use hyper_util::client::legacy::connect::HttpConnector; @@ -237,9 +237,12 @@ async fn main() -> anyhow::Result<()> { fn append_headers(geo: bool, headers: &mut HashMap) { if !headers .keys() - .any(|k| "server_name".eq_ignore_ascii_case(k)) + .any(|k| SERVER_NAME_HEADER.eq_ignore_ascii_case(k)) { - headers.insert("server_name".to_smolstr(), "test.localhost".to_smolstr()); + headers.insert( + SERVER_NAME_HEADER.to_smolstr(), + "test.localhost".to_smolstr(), + ); } if geo {