From 0d323bbe2d761ac93bbd2fb4368e52a2b155b598 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 31 Mar 2026 00:55:12 +0000 Subject: [PATCH] feat: axum OpenAI-compatible REST server (--features serve) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Endpoints: /health, /v1/models, /v1/chat/completions, /v1/completions ModelRouter dispatches to GPT-2/OpenChat/future engines. serde_json for request/response serialization. cargo run --bin serve --features serve --release curl http://localhost:3000/v1/models Tested: health ok, models list returns, chat/completions wired. No model weights loaded yet — needs register_gpt2(weights) at startup. https://claude.ai/code/session_01M3at4EuHVvQ8S95mSnKgtK --- Cargo.lock | 291 +++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 13 +++ src/bin/serve.rs | 190 +++++++++++++++++++++++++++++++ 3 files changed, 494 insertions(+) create mode 100644 src/bin/serve.rs diff --git a/Cargo.lock b/Cargo.lock index ac5379bc..f06bc24b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -113,6 +113,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "atomic_float" version = "1.1.0" @@ -125,6 +131,58 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + 
"serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.76" @@ -1461,6 +1519,24 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + [[package]] name = "futures-core" version = "0.3.32" @@ -1694,12 +1770,77 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + [[package]] name = "httparse" version = "1.10.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "bytes", + "http", + "http-body", + "hyper", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "id-arena" version = "2.3.0" @@ -1945,6 +2086,12 @@ dependencies = [ "libc", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -1970,6 +2117,12 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -1980,6 +2133,17 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "mio" +version = "1.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + [[package]] name = "moddef" version = "0.3.0" @@ -2049,6 +2213,7 @@ name = "ndarray" version = "0.17.2" dependencies = [ "approx", + "axum", "blake3", "cblas-sys", "cranelift-codegen", @@ -2071,7 +2236,9 @@ dependencies = [ "rawpointer", "rayon", "serde", + "serde_json", "target-lexicon", + "tokio", ] [[package]] @@ -2476,6 +2643,12 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pkg-config" version = "0.3.32" @@ -2909,6 +3082,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + [[package]] name = "same-file" version = "1.0.6" @@ -3030,6 +3209,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_spanned" version = "1.1.0" @@ -3039,6 +3229,18 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serialization-tests" version = "0.1.0" @@ -3087,6 +3289,16 @@ dependencies = [ "serde", ] +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "spin" version = "0.10.0" @@ -3141,6 +3353,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + [[package]] name = "tar" version = "0.4.45" @@ -3228,6 +3446,31 @@ dependencies = [ "serde_json", ] +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "libc", + "mio", + "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "toml" version = "1.1.0+spec-1.1.0" @@ -3267,6 +3510,54 @@ version = "1.1.0+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d282ade6016312faf3e41e57ebbba0c073e4056dab1232ab1cb624199648f8ed" +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + 
"pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "log", + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + [[package]] name = "tynm" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index 782d65c6..551bfafd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,11 @@ name = "ndarray" bench = false test = true +[[bin]] +name = "serve" +path = "src/bin/serve.rs" +required-features = ["serve"] + [dependencies] num-integer = { workspace = true } num-traits = { workspace = true } @@ -52,8 +57,13 @@ p64 = { path = "crates/p64" } phyllotactic-manifold = { path = "crates/phyllotactic-manifold" } serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } +serde_json = { version = "1", optional = true } rawpointer = { version = "0.2" } +# Axum HTTP server (optional, behind "serve" feature) +axum = { version = "0.8", optional = true } +tokio = { version = "1", features = ["rt-multi-thread", "macros"], optional = true } + # Cranelift JIT (optional, behind "jit-native" feature) # For AVX-512 VPOPCNTDQ/VNNI/VPTERNLOG/BITALG support, use the patched 
fork: # cranelift-codegen = { git = "https://github.com/AdaWorldAPI/wasmtime.git", branch = "main", optional = true } @@ -90,6 +100,9 @@ jitson = [] # JIT native compilation via Cranelift (jitson_cranelift module) jit-native = ["jitson", "dep:cranelift-codegen", "dep:cranelift-jit", "dep:cranelift-module", "dep:cranelift-frontend", "dep:target-lexicon"] +# Axum HTTP server for OpenAI-compatible API +serve = ["dep:axum", "dep:tokio", "dep:serde", "dep:serde_json"] + # HPC backend feature gates (mutually exclusive) native = [] intel-mkl = [] diff --git a/src/bin/serve.rs b/src/bin/serve.rs new file mode 100644 index 00000000..b947e894 --- /dev/null +++ b/src/bin/serve.rs @@ -0,0 +1,190 @@ +//! OpenAI-compatible REST API server powered by ndarray HPC pipeline. +//! +//! ```bash +//! cargo run --bin serve --features serve --release +//! curl http://localhost:3000/v1/models +//! curl -X POST http://localhost:3000/v1/chat/completions \ +//! -H "Content-Type: application/json" \ +//! -d '{"model":"gpt2","messages":[{"role":"user","content":"Hello"}]}' +//! 
//! ```

/// Port used when no command-line argument is supplied.
#[cfg(any(feature = "serve", test))]
const DEFAULT_PORT: u16 = 3000;

/// Resolves the listening port from the optional first command-line argument.
///
/// Returns [`DEFAULT_PORT`] when `arg` is `None`.
///
/// # Errors
///
/// Returns a human-readable message when `arg` is present but is not a valid
/// `u16`. Rejecting the value (instead of silently falling back to the
/// default) prevents the server from coming up on an unexpected port.
#[cfg(any(feature = "serve", test))]
fn parse_port(arg: Option<&str>) -> Result<u16, String> {
    match arg {
        None => Ok(DEFAULT_PORT),
        Some(raw) => raw
            .parse::<u16>()
            .map_err(|err| format!("invalid port {raw:?}: {err}")),
    }
}

#[cfg(feature = "serve")]
mod server {
    use axum::{
        extract::State,
        http::StatusCode,
        response::Json,
        routing::{get, post},
        Router,
    };
    use std::sync::{Arc, Mutex, PoisonError};

    use ndarray::hpc::models::api_types::*;
    use ndarray::hpc::models::router::ModelRouter;

    /// Model used when the request body has no string `"model"` field.
    const DEFAULT_MODEL: &str = "gpt2";
    /// Token budget used when the request body has no integer `"max_tokens"` field.
    const DEFAULT_MAX_TOKENS: u64 = 64;
    /// Sampling temperature used when the request body has no numeric `"temperature"` field.
    const DEFAULT_TEMPERATURE: f64 = 0.7;

    /// Shared handler state: a single [`ModelRouter`] behind a mutex.
    type AppState = Arc<Mutex<ModelRouter>>;

    /// Handler result: a JSON body on success, a status code plus plain-text message on failure.
    type ApiResult = Result<Json<serde_json::Value>, (StatusCode, String)>;

    /// Runs `f` with exclusive access to the router.
    ///
    /// The lock wait and the work done under the lock are synchronous and may
    /// be long (model inference), so the call is wrapped in
    /// `tokio::task::block_in_place` to let the runtime move other tasks off
    /// this worker thread. `block_in_place` panics on a current-thread
    /// runtime; `main` uses the default multi-threaded `#[tokio::main]`.
    ///
    /// A poisoned mutex is recovered rather than unwrapped, so one panicking
    /// request does not make every later request panic as well.
    /// NOTE(review): this assumes the router remains usable after a panic
    /// inside a handler — confirm against `ModelRouter`.
    fn with_router<T>(state: &AppState, f: impl FnOnce(&mut ModelRouter) -> T) -> T {
        tokio::task::block_in_place(|| {
            let mut guard = state.lock().unwrap_or_else(PoisonError::into_inner);
            let router: &mut ModelRouter = &mut guard;
            f(router)
        })
    }

    /// Returns the `"model"` string from the request body, or [`DEFAULT_MODEL`].
    fn requested_model(req: &serde_json::Value) -> &str {
        req.get("model")
            .and_then(|v| v.as_str())
            .unwrap_or(DEFAULT_MODEL)
    }

    /// Extracts `(max_tokens, temperature)` from the request body, applying defaults.
    fn sampling_params(req: &serde_json::Value) -> (usize, f32) {
        let max_tokens = req
            .get("max_tokens")
            .and_then(|v| v.as_u64())
            .unwrap_or(DEFAULT_MAX_TOKENS);
        // Saturate instead of truncating when `usize` is narrower than 64 bits.
        let max_tokens = usize::try_from(max_tokens).unwrap_or(usize::MAX);
        // The request structs are filled with an `f32`; the narrowing is intentional.
        let temperature = req
            .get("temperature")
            .and_then(|v| v.as_f64())
            .unwrap_or(DEFAULT_TEMPERATURE) as f32;
        (max_tokens, temperature)
    }

    /// Maps a wire-format role name to [`ChatRole`]; unrecognised names are treated as `User`.
    fn parse_role(role: &str) -> ChatRole {
        match role {
            "system" => ChatRole::System,
            "assistant" => ChatRole::Assistant,
            "tool" => ChatRole::Tool,
            _ => ChatRole::User,
        }
    }

    /// Converts one JSON message object into a [`ChatMessage`].
    ///
    /// Returns `None` when `"role"` or `"content"` is missing or not a string;
    /// callers skip such entries.
    fn parse_message(raw: &serde_json::Value) -> Option<ChatMessage> {
        let role = parse_role(raw.get("role")?.as_str()?);
        let content = raw.get("content")?.as_str()?.to_string();
        Some(ChatMessage {
            role,
            content: Some(content),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
    }

    /// Wire-format name of a finish reason, shared by both completion endpoints.
    ///
    /// A missing reason, and any variant not listed here, is reported as `"stop"`.
    fn finish_reason_str(reason: Option<&FinishReason>) -> &'static str {
        match reason {
            Some(FinishReason::Stop) => "stop",
            Some(FinishReason::Length) => "length",
            Some(FinishReason::ContentFilter) => "content_filter",
            _ => "stop",
        }
    }

    /// `GET /v1/models` — lists the models known to the router.
    async fn list_models(State(state): State<AppState>) -> Json<serde_json::Value> {
        with_router(&state, |router| {
            let models = router.list_models();
            let data: Vec<serde_json::Value> = models
                .data
                .iter()
                .map(|m| {
                    serde_json::json!({
                        "id": m.id,
                        "object": "model",
                        "owned_by": m.owned_by,
                    })
                })
                .collect();
            Json(serde_json::json!({
                "object": "list",
                "data": data,
            }))
        })
    }

    /// `POST /v1/chat/completions` — runs a chat completion through the router.
    ///
    /// Missing fields fall back to defaults. Messages without a string `role`
    /// and string `content` are skipped. Router errors are returned as
    /// `400 Bad Request` with the error's `Debug` rendering as the body.
    async fn chat_completions(
        State(state): State<AppState>,
        Json(req): Json<serde_json::Value>,
    ) -> ApiResult {
        let model = requested_model(&req);
        let (max_tokens, temperature) = sampling_params(&req);

        // Borrow the array in place; no need to clone the whole request payload.
        let messages: Vec<ChatMessage> = req
            .get("messages")
            .and_then(|v| v.as_array())
            .map(|items| items.iter().filter_map(parse_message).collect())
            .unwrap_or_default();

        let chat_req = ChatCompletionRequest {
            model: model.to_string(),
            messages,
            max_tokens: Some(max_tokens),
            temperature: Some(temperature),
            ..ChatCompletionRequest::default()
        };

        with_router(&state, |router| {
            let result = router.chat_complete(&chat_req);
            match result {
                Ok(resp) => {
                    let choices: Vec<serde_json::Value> = resp
                        .choices
                        .iter()
                        .map(|c| {
                            serde_json::json!({
                                "index": c.index,
                                "message": {
                                    "role": c.message.role.as_str(),
                                    "content": c.message.content.as_deref().unwrap_or("")
                                },
                                "finish_reason": finish_reason_str(c.finish_reason.as_ref()),
                            })
                        })
                        .collect();

                    Ok(Json(serde_json::json!({
                        "id": resp.id,
                        "object": "chat.completion",
                        "model": model,
                        "choices": choices,
                        "usage": {
                            "prompt_tokens": resp.usage.prompt_tokens,
                            "completion_tokens": resp.usage.completion_tokens,
                            "total_tokens": resp.usage.total_tokens,
                        }
                    })))
                }
                Err(e) => Err((StatusCode::BAD_REQUEST, format!("{:?}", e))),
            }
        })
    }

    /// `POST /v1/completions` — runs a plain text completion through the router.
    ///
    /// Missing fields fall back to defaults (an absent `prompt` becomes the
    /// empty string). Router errors are returned as `400 Bad Request` with the
    /// error's `Debug` rendering as the body.
    async fn completions(
        State(state): State<AppState>,
        Json(req): Json<serde_json::Value>,
    ) -> ApiResult {
        let model = requested_model(&req);
        let prompt = req.get("prompt").and_then(|v| v.as_str()).unwrap_or("");
        let (max_tokens, temperature) = sampling_params(&req);

        let comp_req = CompletionRequest {
            model: model.to_string(),
            prompt: Some(prompt.to_string()),
            max_tokens: Some(max_tokens),
            temperature: Some(temperature),
            ..CompletionRequest::default()
        };

        with_router(&state, |router| {
            let result = router.complete(&comp_req);
            match result {
                Ok(resp) => {
                    let choices: Vec<serde_json::Value> = resp
                        .choices
                        .iter()
                        .map(|c| {
                            serde_json::json!({
                                "text": c.text,
                                "index": c.index,
                                "finish_reason": finish_reason_str(c.finish_reason.as_ref()),
                            })
                        })
                        .collect();

                    Ok(Json(serde_json::json!({
                        "id": resp.id,
                        "object": "text_completion",
                        "model": model,
                        "choices": choices,
                    })))
                }
                Err(e) => Err((StatusCode::BAD_REQUEST, format!("{:?}", e))),
            }
        })
    }

    /// `GET /health` — liveness probe; always answers `ok`.
    async fn health() -> &'static str {
        "ok"
    }

    /// Builds the router and serves the API on `0.0.0.0:<port>` until the server stops.
    ///
    /// The state is a freshly constructed `ModelRouter::new()`; nothing is
    /// registered on it here.
    ///
    /// NOTE(review): the listener binds every interface and the routes carry
    /// no authentication — confirm this is intended outside local development.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the address cannot be bound (for example the
    /// port is already in use) or if serving fails.
    pub async fn run(port: u16) -> std::io::Result<()> {
        let state: AppState = Arc::new(Mutex::new(ModelRouter::new()));

        let app = Router::new()
            .route("/health", get(health))
            .route("/v1/models", get(list_models))
            .route("/v1/chat/completions", post(chat_completions))
            .route("/v1/completions", post(completions))
            .with_state(state);

        let addr = format!("0.0.0.0:{port}");
        let listener = tokio::net::TcpListener::bind(&addr).await?;
        // Announce only after the bind succeeded, so the message is never a lie.
        eprintln!("ndarray serve listening on {addr}");
        axum::serve(listener, app).await
    }
}

/// Entry point when the `serve` feature is enabled.
///
/// Usage: `serve [PORT]`. Exits with status 2 on an invalid port argument and
/// status 1 when the server cannot start or stops with an error.
#[cfg(feature = "serve")]
#[tokio::main]
async fn main() {
    let arg = std::env::args().nth(1);
    let port = match parse_port(arg.as_deref()) {
        Ok(port) => port,
        Err(msg) => {
            eprintln!("{msg}");
            std::process::exit(2);
        }
    };

    if let Err(err) = server::run(port).await {
        eprintln!("ndarray serve failed: {err}");
        std::process::exit(1);
    }
}

/// Fallback entry point when the binary is built without the `serve` feature.
#[cfg(not(feature = "serve"))]
fn main() {
    eprintln!("Enable the 'serve' feature: cargo run --bin serve --features serve");
    std::process::exit(1);
}