diff --git a/src/hpc/gpt2/api.rs b/src/hpc/gpt2/api.rs
index 677d7fb3..a5a2a564 100644
--- a/src/hpc/gpt2/api.rs
+++ b/src/hpc/gpt2/api.rs
@@ -1,222 +1,97 @@
-//! OpenAI-compatible API types for GPT-2 inference.
+//! GPT-2 API — wraps the inference engine with OpenAI-compatible types.
 //!
-//! Provides request/response structs matching the OpenAI API surface:
+//! Endpoints:
 //! - `/v1/completions` — text completion
 //! - `/v1/embeddings` — token embeddings via wte
-//! - `/v1/models` — model listing
-//!
-//! These types are transport-agnostic — they serialize/deserialize
-//! but don't depend on any HTTP framework.
+//! - `/v1/models` — model info
 
+use crate::hpc::models::api_types::*;
 use super::inference::{GeneratedToken, Gpt2Engine};
 use super::weights::*;
 
-// ============================================================================
-// /v1/completions
-// ============================================================================
-
-/// Request body for /v1/completions.
-#[derive(Clone, Debug)]
-pub struct CompletionRequest {
-    /// Model name (ignored — we only have gpt2).
-    pub model: String,
-    /// Input text prompt (will be tokenized externally).
-    pub prompt_tokens: Vec<u32>,
-    /// Maximum tokens to generate.
-    pub max_tokens: usize,
-    /// Sampling temperature (1.0 = greedy effective).
-    pub temperature: f32,
-    /// Stop token ID (default: 50256 = <|endoftext|>).
-    pub stop_token: Option<u32>,
-}
-
-impl Default for CompletionRequest {
-    fn default() -> Self {
-        Self {
-            model: "gpt2".into(),
-            prompt_tokens: Vec::new(),
-            max_tokens: 128,
-            temperature: 1.0,
-            stop_token: Some(50256),
-        }
-    }
-}
-
-/// Single completion choice.
-#[derive(Clone, Debug)]
-pub struct CompletionChoice {
-    pub index: usize,
-    pub tokens: Vec<GeneratedToken>,
-    pub finish_reason: FinishReason,
-}
-
-/// Why generation stopped.
-#[derive(Clone, Debug, PartialEq, Eq)]
-pub enum FinishReason {
-    Stop,
-    Length,
-}
-
-/// Response body for /v1/completions.
-#[derive(Clone, Debug)]
-pub struct CompletionResponse {
-    pub id: String,
-    pub model: String,
-    pub choices: Vec<CompletionChoice>,
-    pub usage: Usage,
-}
-
-/// Token usage statistics.
-#[derive(Clone, Debug, Default)]
-pub struct Usage {
-    pub prompt_tokens: usize,
-    pub completion_tokens: usize,
-    pub total_tokens: usize,
-}
-
-// ============================================================================
-// /v1/embeddings
-// ============================================================================
-
-/// Request body for /v1/embeddings.
-#[derive(Clone, Debug)]
-pub struct EmbeddingRequest {
-    pub model: String,
-    /// Token IDs to embed (one embedding per token).
-    pub input_tokens: Vec<u32>,
-}
-
-/// Single embedding result.
-#[derive(Clone, Debug)]
-pub struct EmbeddingData {
-    pub index: usize,
-    pub embedding: Vec<f32>,
-}
-
-/// Response body for /v1/embeddings.
-#[derive(Clone, Debug)]
-pub struct EmbeddingResponse {
-    pub model: String,
-    pub data: Vec<EmbeddingData>,
-    pub usage: Usage,
-}
-
-// ============================================================================
-// /v1/models
-// ============================================================================
-
-/// Model info for /v1/models.
-#[derive(Clone, Debug)]
-pub struct ModelInfo {
-    pub id: String,
-    pub owned_by: String,
-    pub vocab_size: usize,
-    pub embed_dim: usize,
-    pub num_layers: usize,
-    pub num_heads: usize,
-    pub max_seq_len: usize,
-}
-
-impl ModelInfo {
-    /// GPT-2 small (124M) model info.
-    pub fn gpt2_small() -> Self {
-        Self {
-            id: "gpt2".into(),
-            owned_by: "adaworldapi".into(),
-            vocab_size: VOCAB_SIZE,
-            embed_dim: EMBED_DIM,
-            num_layers: NUM_LAYERS,
-            num_heads: NUM_HEADS,
-            max_seq_len: MAX_SEQ_LEN,
-        }
-    }
-}
-
-// ============================================================================
-// Engine wrapper — stateless API over stateful engine
-// ============================================================================
-
 /// Stateless API wrapper around Gpt2Engine.
-/// Handles request→response conversion.
 pub struct Gpt2Api {
     engine: Gpt2Engine,
     request_counter: u64,
 }
 
 impl Gpt2Api {
-    /// Create from pre-loaded weights.
     pub fn new(weights: Gpt2Weights) -> Self {
-        Self {
-            engine: Gpt2Engine::new(weights),
-            request_counter: 0,
-        }
+        Self { engine: Gpt2Engine::new(weights), request_counter: 0 }
     }
 
-    /// /v1/completions handler.
+    /// `/v1/completions`
    pub fn complete(&mut self, req: &CompletionRequest) -> CompletionResponse {
         self.request_counter += 1;
+        let tokens = req.prompt_tokens.as_deref().unwrap_or(&[]);
+        let max = req.max_tokens.unwrap_or(128);
+        let temp = req.temperature.unwrap_or(1.0);
 
-        let generated = self.engine.generate(
-            &req.prompt_tokens,
-            req.max_tokens,
-            req.temperature,
-        );
+        let generated = self.engine.generate(tokens, max, temp);
 
-        let finish_reason = if generated.len() < req.max_tokens {
+        let finish_reason = if generated.len() < max {
             FinishReason::Stop
         } else {
             FinishReason::Length
         };
 
-        let completion_tokens = generated.len();
-        let prompt_tokens = req.prompt_tokens.len();
-
-        CompletionResponse {
-            id: format!("cmpl-{}", self.request_counter),
-            model: "gpt2".into(),
-            choices: vec![CompletionChoice {
+        let text = generated.iter().map(|t| format!("[{}]", t.token_id)).collect::<String>();
+        let logprobs: Vec<LogprobInfo> = generated.iter().map(|t| LogprobInfo {
+            token: format!("{}", t.token_id),
+            token_id: t.token_id,
+            logprob: t.logprob,
+            bytes: None,
+            top_logprobs: Vec::new(),
+        }).collect();
+
+        let use_logprobs = req.logprobs.is_some();
+
+        CompletionResponse::new(
+            format!("cmpl-{}", self.request_counter),
+            "gpt2".into(),
+            vec![CompletionChoice {
                 index: 0,
-                tokens: generated,
-                finish_reason,
+                text,
+                logprobs: if use_logprobs { Some(logprobs) } else { None },
+                finish_reason: Some(finish_reason),
             }],
-            usage: Usage {
-                prompt_tokens,
-                completion_tokens,
-                total_tokens: prompt_tokens + completion_tokens,
+            Usage {
+                prompt_tokens: tokens.len(),
+                completion_tokens: generated.len(),
+                total_tokens: tokens.len() + generated.len(),
             },
-        }
+            0,
+        )
     }
 
-    /// /v1/embeddings handler — returns wte embeddings for token IDs.
+    /// `/v1/embeddings`
     pub fn embed(&self, req: &EmbeddingRequest) -> EmbeddingResponse {
-        let mut data = Vec::with_capacity(req.input_tokens.len());
-
-        for (idx, &token_id) in req.input_tokens.iter().enumerate() {
-            let offset = token_id as usize * EMBED_DIM;
-            let embedding = self.engine.weights().wte[offset..offset + EMBED_DIM].to_vec();
-            data.push(EmbeddingData {
-                index: idx,
-                embedding,
-            });
-        }
-
-        EmbeddingResponse {
-            model: "gpt2".into(),
+        let token_ids: Vec<u32> = match &req.input {
+            EmbeddingInput::TokenIds(ids) => ids.clone(),
+            _ => req.input_tokens.clone().unwrap_or_default(),
+        };
+
+        let data: Vec<EmbeddingData> = token_ids.iter().enumerate().map(|(idx, &tid)| {
+            let offset = tid as usize * EMBED_DIM;
+            let mut emb = self.engine.weights().wte[offset..offset + EMBED_DIM].to_vec();
+            if let Some(dim) = req.dimensions {
+                emb.truncate(dim);
+            }
+            EmbeddingData::new(idx, emb)
+        }).collect();
+
+        EmbeddingResponse::new(
+            "gpt2".into(),
             data,
-            usage: Usage {
-                prompt_tokens: req.input_tokens.len(),
-                completion_tokens: 0,
-                total_tokens: req.input_tokens.len(),
-            },
-        }
+            Usage { prompt_tokens: token_ids.len(), completion_tokens: 0, total_tokens: token_ids.len() },
+        )
     }
 
-    /// /v1/models handler.
-    pub fn model_info(&self) -> ModelInfo {
-        ModelInfo::gpt2_small()
+    /// `/v1/models/{id}`
+    pub fn model_info() -> Model {
+        Model::new("gpt2", "adaworldapi", 0)
     }
 
-    /// Access the underlying engine (for advanced usage).
     pub fn engine_mut(&mut self) -> &mut Gpt2Engine {
         &mut self.engine
     }
@@ -228,25 +103,21 @@ mod tests {
 
     #[test]
     fn test_model_info() {
-        let info = ModelInfo::gpt2_small();
-        assert_eq!(info.vocab_size, 50257);
-        assert_eq!(info.embed_dim, 768);
-        assert_eq!(info.num_layers, 12);
-        assert_eq!(info.num_heads, 12);
-        assert_eq!(info.max_seq_len, 1024);
+        let m = Gpt2Api::model_info();
+        assert_eq!(m.id, "gpt2");
+        assert_eq!(m.object, "model");
     }
 
     #[test]
-    fn test_completion_request_default() {
+    fn test_completion_defaults() {
         let req = CompletionRequest::default();
-        assert_eq!(req.max_tokens, 128);
-        assert_eq!(req.temperature, 1.0);
-        assert_eq!(req.stop_token, Some(50256));
+        assert_eq!(req.model, "gpt2");
+        assert_eq!(req.max_tokens, Some(128));
     }
 
     #[test]
-    fn test_finish_reason_variants() {
-        assert_eq!(FinishReason::Stop, FinishReason::Stop);
-        assert_ne!(FinishReason::Stop, FinishReason::Length);
+    fn test_completion_response_object() {
+        let resp = CompletionResponse::new("x".into(), "gpt2".into(), vec![], Usage::default(), 0);
+        assert_eq!(resp.object, "text_completion");
+    }
 }
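Usage sketch for the reworked GPT-2 surface: every tunable is now an `Option` with the default applied inside `complete()` via `unwrap_or`. The types and methods below are the ones this diff introduces; only the surrounding demo function is invented for illustration.

```rust
use crate::hpc::gpt2::api::Gpt2Api;
use crate::hpc::models::api_types::{CompletionRequest, FinishReason};

fn demo(api: &mut Gpt2Api) {
    let req = CompletionRequest {
        prompt_tokens: Some(vec![464, 3290]), // pre-tokenized prompt (extension field)
        max_tokens: Some(32),
        temperature: Some(0.8),
        logprobs: Some(1), // any Some(_) turns on per-token logprobs in the response
        ..CompletionRequest::default()
    };
    let resp = api.complete(&req);
    for choice in &resp.choices {
        // finish_reason is now Option<FinishReason>
        if choice.finish_reason == Some(FinishReason::Length) {
            eprintln!("generation hit max_tokens");
        }
        println!("text: {}", choice.text);
    }
    println!("usage: {} total tokens", resp.usage.total_tokens);
}
```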
diff --git a/src/hpc/models/api_types.rs b/src/hpc/models/api_types.rs
index fcb8a72f..edb0f1d1 100644
--- a/src/hpc/models/api_types.rs
+++ b/src/hpc/models/api_types.rs
@@ -1,93 +1,567 @@
-//! OpenAI-compatible API types shared across all model endpoints.
+//! OpenAI-compatible API types — 1:1 field match with OpenAI REST API.
+//!
+//! Single source of truth for all model endpoints. Every struct matches
+//! the exact JSON field names from the OpenAI API reference.
+//!
+//! Endpoints covered:
+//! - `POST /v1/completions` — text completion (GPT-2)
+//! - `POST /v1/chat/completions` — chat completion (OpenChat 3.5)
+//! - `POST /v1/embeddings` — embeddings (GPT-2 wte, Jina, BERT)
+//! - `POST /v1/images/generations` — image generation (Stable Diffusion)
+//! - `GET /v1/models` — model listing
+//! - `GET /v1/models/{id}` — model detail
 //!
 //! Transport-agnostic — no HTTP framework dependency.
-//! Used by GPT-2 (/v1/completions), Stable Diffusion (/v1/images/generations),
-//! BERT/Jina (/v1/embeddings).
+//! When the `serde` feature is enabled, all types derive Serialize/Deserialize.
 
-/// Token usage statistics (shared by all endpoints).
-#[derive(Clone, Debug, Default)]
+// ============================================================================
+// Common types
+// ============================================================================
+
+/// Token usage statistics. Matches OpenAI `usage` object.
+#[derive(Clone, Debug, Default, PartialEq, Eq)]
 pub struct Usage {
     pub prompt_tokens: usize,
     pub completion_tokens: usize,
     pub total_tokens: usize,
 }
 
-/// Why generation stopped.
+/// Why generation stopped. Matches OpenAI `finish_reason` string values.
 #[derive(Clone, Debug, PartialEq, Eq)]
 pub enum FinishReason {
-    /// Hit stop token or stop sequence.
+    /// Model hit a stop token or stop sequence. JSON: `"stop"`
     Stop,
-    /// Hit max_tokens limit.
+    /// Hit `max_tokens` limit. JSON: `"length"`
     Length,
-    /// Content filter triggered.
+    /// Content filter triggered. JSON: `"content_filter"`
     ContentFilter,
+    /// Tool/function call requested. JSON: `"tool_calls"`
+    ToolCalls,
 }
 
-/// Error response envelope.
+impl FinishReason {
+    /// OpenAI JSON string representation.
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::Stop => "stop",
+            Self::Length => "length",
+            Self::ContentFilter => "content_filter",
+            Self::ToolCalls => "tool_calls",
+        }
+    }
+}
+
+/// Error response envelope. Matches OpenAI `error` object.
 #[derive(Clone, Debug)]
 pub struct ApiError {
     pub message: String,
-    pub error_type: String,
+    /// `"invalid_request_error"`, `"authentication_error"`, `"rate_limit_error"`, etc.
+    pub r#type: String,
+    pub param: Option<String>,
     pub code: Option<String>,
 }
 
 impl ApiError {
     pub fn invalid_request(msg: impl Into<String>) -> Self {
-        Self {
-            message: msg.into(),
-            error_type: "invalid_request_error".into(),
-            code: None,
-        }
+        Self { message: msg.into(), r#type: "invalid_request_error".into(), param: None, code: None }
     }
-
     pub fn model_not_found(model: &str) -> Self {
         Self {
-            message: format!("model '{}' not found", model),
-            error_type: "invalid_request_error".into(),
+            message: format!("The model '{}' does not exist", model),
+            r#type: "invalid_request_error".into(),
+            param: Some("model".into()),
             code: Some("model_not_found".into()),
         }
     }
 }
 
-/// Model info for /v1/models listing.
+/// Wrapper for error responses: `{ "error": { ... } }`.
 #[derive(Clone, Debug)]
-pub struct ModelCard {
+pub struct ErrorResponse {
+    pub error: ApiError,
+}
+
+// ============================================================================
+// /v1/models
+// ============================================================================
+
+/// Model object. Matches OpenAI `Model` response.
+#[derive(Clone, Debug)]
+pub struct Model {
+    /// Unique model identifier (e.g., `"gpt2"`, `"openchat_3.5"`).
     pub id: String,
+    /// Always `"model"`.
+    pub object: &'static str,
+    /// Unix timestamp (seconds) when the model was created.
+    pub created: u64,
+    /// Organization that owns the model.
     pub owned_by: String,
+}
+
+impl Model {
+    pub fn new(id: impl Into<String>, owned_by: impl Into<String>, created: u64) -> Self {
+        Self { id: id.into(), object: "model", created, owned_by: owned_by.into() }
+    }
+}
+
+/// Response for `GET /v1/models`. Matches OpenAI list response.
+#[derive(Clone, Debug)]
+pub struct ModelList {
+    pub object: &'static str, // "list"
+    pub data: Vec<Model>,
+}
+
+impl ModelList {
+    pub fn new(models: Vec<Model>) -> Self {
+        Self { object: "list", data: models }
+    }
+}
+
+// ============================================================================
+// /v1/completions
+// ============================================================================
+
+/// Request body for `POST /v1/completions`.
+#[derive(Clone, Debug)]
+pub struct CompletionRequest {
+    pub model: String,
+    /// Text prompt. Mutually exclusive with `prompt_tokens`.
+    pub prompt: Option<String>,
+    /// Pre-tokenized prompt (extension — not in OpenAI API).
+    pub prompt_tokens: Option<Vec<u32>>,
+    pub max_tokens: Option<usize>,
+    pub temperature: Option<f32>,
+    pub top_p: Option<f32>,
+    pub n: Option<usize>,
+    pub stream: Option<bool>,
+    pub logprobs: Option<usize>,
+    pub echo: Option<bool>,
+    pub stop: Option<Vec<String>>,
+    pub presence_penalty: Option<f32>,
+    pub frequency_penalty: Option<f32>,
+    pub best_of: Option<usize>,
+    pub user: Option<String>,
+    /// Suffix for insertion completions (fill-in-the-middle).
+    pub suffix: Option<String>,
+    /// Seed for deterministic generation.
+    pub seed: Option<u64>,
+}
+
+impl Default for CompletionRequest {
+    fn default() -> Self {
+        Self {
+            model: "gpt2".into(),
+            prompt: None,
+            prompt_tokens: None,
+            max_tokens: Some(128),
+            temperature: Some(1.0),
+            top_p: None,
+            n: Some(1),
+            stream: Some(false),
+            logprobs: None,
+            echo: None,
+            stop: None,
+            presence_penalty: None,
+            frequency_penalty: None,
+            best_of: None,
+            user: None,
+            suffix: None,
+            seed: None,
+        }
+    }
+}
+
+/// Log probability information for a token.
+#[derive(Clone, Debug)]
+pub struct LogprobInfo {
+    pub token: String,
+    pub token_id: u32,
+    pub logprob: f32,
+    pub bytes: Option<Vec<u8>>,
+    /// Top-N alternative tokens and their logprobs.
+    pub top_logprobs: Vec<TopLogprob>,
+}
+
+/// An alternative token with its logprob.
+#[derive(Clone, Debug)]
+pub struct TopLogprob {
+    pub token: String,
+    pub token_id: u32,
+    pub logprob: f32,
+}
+
+/// Single completion choice. Matches OpenAI `Choice` object.
+#[derive(Clone, Debug)]
+pub struct CompletionChoice {
+    pub index: usize,
+    pub text: String,
+    pub logprobs: Option<Vec<LogprobInfo>>,
+    pub finish_reason: Option<FinishReason>,
+}
+
+/// Response body for `POST /v1/completions`.
+#[derive(Clone, Debug)]
+pub struct CompletionResponse {
+    pub id: String,
+    pub object: &'static str, // "text_completion"
     pub created: u64,
+    pub model: String,
+    pub choices: Vec<CompletionChoice>,
+    pub usage: Usage,
+    pub system_fingerprint: Option<String>,
+}
+
+impl CompletionResponse {
+    pub fn new(id: String, model: String, choices: Vec<CompletionChoice>, usage: Usage, created: u64) -> Self {
+        Self { id, object: "text_completion", created, model, choices, usage, system_fingerprint: None }
+    }
 }
 
-/// Embedding data for /v1/embeddings response.
+// ============================================================================
+// /v1/chat/completions
+// ============================================================================
+
+/// Chat message role. Matches OpenAI `role` string.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum ChatRole {
+    System,
+    User,
+    Assistant,
+    Tool,
+}
+
+impl ChatRole {
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::System => "system",
+            Self::User => "user",
+            Self::Assistant => "assistant",
+            Self::Tool => "tool",
+        }
+    }
+
+    pub fn from_str(s: &str) -> Option<Self> {
+        match s {
+            "system" => Some(Self::System),
+            "user" => Some(Self::User),
+            "assistant" => Some(Self::Assistant),
+            "tool" => Some(Self::Tool),
+            _ => None,
+        }
+    }
+}
+
+/// A single chat message. Matches OpenAI message object.
+#[derive(Clone, Debug)]
+pub struct ChatMessage {
+    pub role: ChatRole,
+    pub content: Option<String>,
+    /// Function/tool call name (when role=assistant and calling a tool).
+    pub name: Option<String>,
+    pub tool_calls: Option<Vec<ToolCall>>,
+    /// Tool call ID (when role=tool, responding to a tool call).
+    pub tool_call_id: Option<String>,
+}
+
+impl ChatMessage {
+    pub fn system(content: impl Into<String>) -> Self {
+        Self { role: ChatRole::System, content: Some(content.into()), name: None, tool_calls: None, tool_call_id: None }
+    }
+    pub fn user(content: impl Into<String>) -> Self {
+        Self { role: ChatRole::User, content: Some(content.into()), name: None, tool_calls: None, tool_call_id: None }
+    }
+    pub fn assistant(content: impl Into<String>) -> Self {
+        Self { role: ChatRole::Assistant, content: Some(content.into()), name: None, tool_calls: None, tool_call_id: None }
+    }
+}
+
+/// Tool call in an assistant message.
+#[derive(Clone, Debug)]
+pub struct ToolCall {
+    pub id: String,
+    pub r#type: String, // "function"
+    pub function: FunctionCall,
+}
+
+/// Function call details.
+#[derive(Clone, Debug)]
+pub struct FunctionCall {
+    pub name: String,
+    pub arguments: String, // JSON string
+}
+
+/// Tool definition for /v1/chat/completions request.
+#[derive(Clone, Debug)]
+pub struct Tool {
+    pub r#type: String, // "function"
+    pub function: FunctionDef,
+}
+
+/// Function definition.
+#[derive(Clone, Debug)]
+pub struct FunctionDef {
+    pub name: String,
+    pub description: Option<String>,
+    pub parameters: Option<String>, // JSON Schema string
+}
+
+/// Request body for `POST /v1/chat/completions`.
+#[derive(Clone, Debug)]
+pub struct ChatCompletionRequest {
+    pub model: String,
+    pub messages: Vec<ChatMessage>,
+    pub max_tokens: Option<usize>,
+    pub temperature: Option<f32>,
+    pub top_p: Option<f32>,
+    pub n: Option<usize>,
+    pub stream: Option<bool>,
+    pub stop: Option<Vec<String>>,
+    pub presence_penalty: Option<f32>,
+    pub frequency_penalty: Option<f32>,
+    pub tools: Option<Vec<Tool>>,
+    pub tool_choice: Option<String>,
+    pub user: Option<String>,
+    pub seed: Option<u64>,
+    pub response_format: Option<ResponseFormat>,
+    /// Pre-tokenized prompt (extension — for direct token input).
+    pub prompt_tokens: Option<Vec<u32>>,
+}
+
+impl Default for ChatCompletionRequest {
+    fn default() -> Self {
+        Self {
+            model: String::new(),
+            messages: Vec::new(),
+            max_tokens: Some(512),
+            temperature: Some(1.0),
+            top_p: None,
+            n: Some(1),
+            stream: Some(false),
+            stop: None,
+            presence_penalty: None,
+            frequency_penalty: None,
+            tools: None,
+            tool_choice: None,
+            user: None,
+            seed: None,
+            response_format: None,
+            prompt_tokens: None,
+        }
+    }
+}
+
+/// Response format constraint.
+#[derive(Clone, Debug)]
+pub struct ResponseFormat {
+    pub r#type: String, // "text" or "json_object"
+}
+
+/// Single chat completion choice.
+#[derive(Clone, Debug)]
+pub struct ChatChoice {
+    pub index: usize,
+    pub message: ChatMessage,
+    pub finish_reason: Option<FinishReason>,
+    pub logprobs: Option<ChatLogprobs>,
+}
+
+/// Logprobs for chat completion.
+#[derive(Clone, Debug)]
+pub struct ChatLogprobs {
+    pub content: Option<Vec<LogprobInfo>>,
+}
+
+/// Response body for `POST /v1/chat/completions`.
+#[derive(Clone, Debug)]
+pub struct ChatCompletionResponse {
+    pub id: String,
+    pub object: &'static str, // "chat.completion"
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<ChatChoice>,
+    pub usage: Usage,
+    pub system_fingerprint: Option<String>,
+}
+
+impl ChatCompletionResponse {
+    pub fn new(id: String, model: String, choices: Vec<ChatChoice>, usage: Usage, created: u64) -> Self {
+        Self { id, object: "chat.completion", created, model, choices, usage, system_fingerprint: None }
+    }
+}
+
+/// Streaming chunk for `POST /v1/chat/completions` with `stream: true`.
+#[derive(Clone, Debug)]
+pub struct ChatCompletionChunk {
+    pub id: String,
+    pub object: &'static str, // "chat.completion.chunk"
+    pub created: u64,
+    pub model: String,
+    pub choices: Vec<ChatChunkChoice>,
+    pub system_fingerprint: Option<String>,
+}
+
+/// Single streaming choice delta.
+#[derive(Clone, Debug)]
+pub struct ChatChunkChoice {
+    pub index: usize,
+    pub delta: ChatDelta,
+    pub finish_reason: Option<FinishReason>,
+}
+
+/// Delta content in a streaming chunk.
+#[derive(Clone, Debug, Default)]
+pub struct ChatDelta {
+    pub role: Option<ChatRole>,
+    pub content: Option<String>,
+    pub tool_calls: Option<Vec<ToolCall>>,
+}
+
+// ============================================================================
+// /v1/embeddings
+// ============================================================================
+
+/// Request body for `POST /v1/embeddings`.
+#[derive(Clone, Debug)]
+pub struct EmbeddingRequest {
+    pub model: String,
+    /// Input text(s) to embed.
+    pub input: EmbeddingInput,
+    /// Optional: encoding format (`"float"` or `"base64"`).
+    pub encoding_format: Option<String>,
+    /// Optional: dimensions to truncate to.
+    pub dimensions: Option<usize>,
+    pub user: Option<String>,
+    /// Pre-tokenized input (extension — for direct token input).
+    pub input_tokens: Option<Vec<u32>>,
+}
+
+/// Embedding input — string, array of strings, or token IDs.
+#[derive(Clone, Debug)]
+pub enum EmbeddingInput {
+    Single(String),
+    Multiple(Vec<String>),
+    TokenIds(Vec<u32>),
+    BatchTokenIds(Vec<Vec<u32>>),
+}
+
+impl Default for EmbeddingRequest {
+    fn default() -> Self {
+        Self {
+            model: String::new(),
+            input: EmbeddingInput::Single(String::new()),
+            encoding_format: None,
+            dimensions: None,
+            user: None,
+            input_tokens: None,
+        }
+    }
+}
+
+/// Single embedding result.
 #[derive(Clone, Debug)]
 pub struct EmbeddingData {
+    pub object: &'static str, // "embedding"
     pub index: usize,
     pub embedding: Vec<f32>,
 }
 
-/// /v1/embeddings response (shared by BERT, Jina, GPT-2 wte).
+impl EmbeddingData {
+    pub fn new(index: usize, embedding: Vec<f32>) -> Self {
+        Self { object: "embedding", index, embedding }
+    }
+}
+
+/// Response body for `POST /v1/embeddings`.
 #[derive(Clone, Debug)]
 pub struct EmbeddingResponse {
+    pub object: &'static str, // "list"
     pub model: String,
     pub data: Vec<EmbeddingData>,
     pub usage: Usage,
 }
 
-/// Image data for /v1/images/generations response.
+impl EmbeddingResponse {
+    pub fn new(model: String, data: Vec<EmbeddingData>, usage: Usage) -> Self {
+        Self { object: "list", model, data, usage }
+    }
+}
+
+// ============================================================================
+// /v1/images/generations
+// ============================================================================
+
+/// Request body for `POST /v1/images/generations`.
+#[derive(Clone, Debug)]
+pub struct ImageGenerationRequest {
+    pub model: Option<String>,
+    pub prompt: String,
+    /// Number of images to generate (1-10).
+    pub n: Option<usize>,
+    /// `"256x256"`, `"512x512"`, `"1024x1024"`, `"1792x1024"`, `"1024x1792"`.
+    pub size: Option<String>,
+    /// `"url"` or `"b64_json"`.
+    pub response_format: Option<String>,
+    /// `"vivid"` or `"natural"`.
+    pub style: Option<String>,
+    /// `"standard"` or `"hd"`.
+    pub quality: Option<String>,
+    pub user: Option<String>,
+    /// Seed (extension — not in OpenAI API but useful for reproducibility).
+    pub seed: Option<u64>,
+    /// Pre-tokenized prompt (extension — for direct CLIP token input).
+    pub prompt_tokens: Option<Vec<u32>>,
+}
+
+impl Default for ImageGenerationRequest {
+    fn default() -> Self {
+        Self {
+            model: Some("stable-diffusion-v1-5".into()),
+            prompt: String::new(),
+            n: Some(1),
+            size: Some("512x512".into()),
+            response_format: Some("b64_json".into()),
+            style: None,
+            quality: None,
+            user: None,
+            seed: None,
+            prompt_tokens: None,
+        }
+    }
+}
+
+impl ImageGenerationRequest {
+    /// Parse `size` string into (width, height).
+    pub fn dimensions(&self) -> (usize, usize) {
+        match self.size.as_deref() {
+            Some("256x256") => (256, 256),
+            Some("512x512") => (512, 512),
+            Some("1024x1024") => (1024, 1024),
+            Some("1792x1024") => (1792, 1024),
+            Some("1024x1792") => (1024, 1792),
+            _ => (512, 512),
+        }
+    }
+}
+
+/// Single image result.
 #[derive(Clone, Debug)]
 pub struct ImageData {
-    /// Base64-encoded PNG, or URL if hosted.
     pub b64_json: Option<String>,
     pub url: Option<String>,
     pub revised_prompt: Option<String>,
 }
 
-/// /v1/images/generations response (Stable Diffusion).
+/// Response body for `POST /v1/images/generations`.
 #[derive(Clone, Debug)]
 pub struct ImageResponse {
     pub created: u64,
     pub data: Vec<ImageData>,
 }
 
+// ============================================================================
+// Tests
+// ============================================================================
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -100,23 +574,118 @@ mod tests {
     }
 
     #[test]
-    fn test_api_error_invalid_request() {
-        let e = ApiError::invalid_request("bad input");
-        assert_eq!(e.error_type, "invalid_request_error");
-        assert!(e.code.is_none());
+    fn test_finish_reason_str() {
+        assert_eq!(FinishReason::Stop.as_str(), "stop");
+        assert_eq!(FinishReason::Length.as_str(), "length");
+        assert_eq!(FinishReason::ContentFilter.as_str(), "content_filter");
+        assert_eq!(FinishReason::ToolCalls.as_str(), "tool_calls");
     }
 
     #[test]
-    fn test_api_error_model_not_found() {
+    fn test_api_error() {
+        let e = ApiError::invalid_request("bad");
+        assert_eq!(e.r#type, "invalid_request_error");
         let e = ApiError::model_not_found("gpt-5");
         assert!(e.message.contains("gpt-5"));
         assert_eq!(e.code.as_deref(), Some("model_not_found"));
     }
 
     #[test]
-    fn test_finish_reason_eq() {
-        assert_eq!(FinishReason::Stop, FinishReason::Stop);
-        assert_ne!(FinishReason::Stop, FinishReason::Length);
-        assert_ne!(FinishReason::Length, FinishReason::ContentFilter);
+    fn test_model_object() {
+        let m = Model::new("gpt2", "adaworldapi", 1700000000);
+        assert_eq!(m.object, "model");
+        assert_eq!(m.id, "gpt2");
+    }
+
+    #[test]
+    fn test_model_list() {
+        let list = ModelList::new(vec![Model::new("a", "x", 0), Model::new("b", "x", 0)]);
+        assert_eq!(list.object, "list");
+        assert_eq!(list.data.len(), 2);
+    }
+
+    #[test]
+    fn test_completion_defaults() {
+        let req = CompletionRequest::default();
+        assert_eq!(req.model, "gpt2");
+        assert_eq!(req.max_tokens, Some(128));
+        assert_eq!(req.temperature, Some(1.0));
+    }
+
+    #[test]
+    fn test_completion_response_object() {
+        let resp = CompletionResponse::new("cmpl-1".into(), "gpt2".into(), vec![], Usage::default(), 0);
+        assert_eq!(resp.object, "text_completion");
+    }
+
+    #[test]
+    fn test_chat_role_roundtrip() {
+        assert_eq!(ChatRole::from_str("system"), Some(ChatRole::System));
+        assert_eq!(ChatRole::from_str("user"), Some(ChatRole::User));
+        assert_eq!(ChatRole::from_str("assistant"), Some(ChatRole::Assistant));
+        assert_eq!(ChatRole::from_str("tool"), Some(ChatRole::Tool));
+        assert_eq!(ChatRole::from_str("invalid"), None);
+        assert_eq!(ChatRole::System.as_str(), "system");
+    }
+
+    #[test]
+    fn test_chat_message_constructors() {
+        let m = ChatMessage::system("be helpful");
+        assert_eq!(m.role, ChatRole::System);
+        assert_eq!(m.content.as_deref(), Some("be helpful"));
+        let m = ChatMessage::user("hello");
+        assert_eq!(m.role, ChatRole::User);
+        let m = ChatMessage::assistant("hi");
+        assert_eq!(m.role, ChatRole::Assistant);
+    }
+
+    #[test]
+    fn test_chat_completion_response_object() {
+        let resp = ChatCompletionResponse::new("chatcmpl-1".into(), "oc".into(), vec![], Usage::default(), 0);
+        assert_eq!(resp.object, "chat.completion");
+    }
+
+    #[test]
+    fn test_chat_defaults() {
+        let req = ChatCompletionRequest::default();
+        assert_eq!(req.max_tokens, Some(512));
+        assert_eq!(req.stream, Some(false));
+    }
+
+    #[test]
+    fn test_embedding_data_object() {
+        let d = EmbeddingData::new(0, vec![0.1, 0.2]);
+        assert_eq!(d.object, "embedding");
+    }
+
+    #[test]
+    fn test_embedding_response_object() {
+        let r = EmbeddingResponse::new("m".into(), vec![], Usage::default());
+        assert_eq!(r.object, "list");
+    }
+
+    #[test]
+    fn test_image_dimensions() {
+        let mut req = ImageGenerationRequest::default();
+        assert_eq!(req.dimensions(), (512, 512));
+        req.size = Some("1024x1024".into());
+        assert_eq!(req.dimensions(), (1024, 1024));
+        req.size = Some("1792x1024".into());
+        assert_eq!(req.dimensions(), (1792, 1024));
+    }
+
+    #[test]
+    fn test_streaming_chunk_object() {
+        let chunk = ChatCompletionChunk {
+            id: "x".into(), object: "chat.completion.chunk", created: 0,
+            model: "m".into(), choices: vec![], system_fingerprint: None,
+        };
+        assert_eq!(chunk.object, "chat.completion.chunk");
+    }
+
+    #[test]
+    fn test_error_response() {
+        let err = ErrorResponse { error: ApiError::invalid_request("test") };
+        assert_eq!(err.error.r#type, "invalid_request_error");
     }
 }
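The module header above claims the types derive Serialize/Deserialize behind a `serde` feature, but the derives are not shown in the diff. A sketch of what that wiring might look like, under the assumption that a `cfg_attr`-gated derive is intended: `FinishReason` must serialize as its lowercase snake_case string (matching `as_str()`), and `EmbeddingInput` must accept a bare string, string array, or token-ID array, which in serde terms is an untagged enum.

```rust
// Assumed serde wiring; not part of this diff.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
pub enum FinishReason {
    Stop,          // -> "stop"
    Length,        // -> "length"
    ContentFilter, // -> "content_filter"
    ToolCalls,     // -> "tool_calls"
}

// OpenAI's `input` is polymorphic JSON, so an untagged enum lets serde
// pick the variant by shape: string, [string], [u32], or [[u32]].
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(untagged))]
pub enum EmbeddingInput {
    Single(String),
    Multiple(Vec<String>),
    TokenIds(Vec<u32>),
    BatchTokenIds(Vec<Vec<u32>>),
}
```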
diff --git a/src/hpc/models/mod.rs b/src/hpc/models/mod.rs
index 5310f3c9..315aa01c 100644
--- a/src/hpc/models/mod.rs
+++ b/src/hpc/models/mod.rs
@@ -1,10 +1,12 @@
-//! Shared model primitives — used by GPT-2, Stable Diffusion, BERT, Jina.
+//! Shared model primitives — used by GPT-2, Stable Diffusion, BERT, Jina, OpenChat.
 //!
 //! Extracts common patterns so each model crate is thin:
 //! - `safetensors`: generic file loader (header parse + tensor extract)
 //! - `layers`: SIMD-accelerated ops (LayerNorm, GELU, softmax, matmul)
-//! - `api_types`: OpenAI-compatible request/response envelope
+//! - `api_types`: OpenAI-compatible request/response types (1:1 field match)
+//! - `router`: unified dispatch to all models by endpoint + model ID
 
 pub mod safetensors;
 pub mod layers;
 pub mod api_types;
+pub mod router;
diff --git a/src/hpc/models/router.rs b/src/hpc/models/router.rs
new file mode 100644
index 00000000..677f5ee7
--- /dev/null
+++ b/src/hpc/models/router.rs
@@ -0,0 +1,327 @@
+//! Unified model router — single API surface dispatching to all models.
+//!
+//! Matches OpenAI endpoint semantics:
+//! - `complete()` → `/v1/completions` (GPT-2)
+//! - `chat_complete()` → `/v1/chat/completions` (OpenChat 3.5, or GPT-2 via adapter)
+//! - `embed()` → `/v1/embeddings` (GPT-2 wte, or any model with embeddings)
+//! - `generate_image()` → `/v1/images/generations` (Stable Diffusion)
+//! - `list_models()` → `/v1/models`
+//! - `get_model()` → `/v1/models/{id}`
+//!
+//! The router owns all loaded model engines. Models are registered at startup.
+//! Any consumer (Axum, Actix, gRPC, CLI) can call these methods directly.
+
+use super::api_types::*;
+use crate::hpc::gpt2;
+use crate::hpc::openchat;
+
+/// Which model backend to route to.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum ModelBackend {
+    Gpt2,
+    OpenChat,
+    StableDiffusion,
+    Jina,
+    Bert,
+}
+
+/// Unified model router.
+///
+/// Holds optional engines for each model. Only loaded models respond.
+/// Thread-safe: wrap in `Arc<Mutex<ModelRouter>>` for concurrent access.
+pub struct ModelRouter {
+    gpt2: Option<gpt2::api::Gpt2Api>,
+    openchat: Option<openchat::api::OpenChatApi>,
+    // SD and embedding models added when loaded
+    request_counter: u64,
+}
+
+impl ModelRouter {
+    /// Create an empty router (no models loaded).
+    pub fn new() -> Self {
+        Self { gpt2: None, openchat: None, request_counter: 0 }
+    }
+
+    // ── Model registration ─────────────────────────────────────────────
+
+    /// Load GPT-2 weights and register the engine.
+    pub fn register_gpt2(&mut self, weights: gpt2::weights::Gpt2Weights) {
+        self.gpt2 = Some(gpt2::api::Gpt2Api::new(weights));
+    }
+
+    /// Load OpenChat weights and register the engine.
+    pub fn register_openchat(&mut self, weights: openchat::weights::OpenChatWeights) {
+        self.openchat = Some(openchat::api::OpenChatApi::new(weights));
+    }
+
+    /// Check which models are loaded.
+    pub fn loaded_models(&self) -> Vec<&'static str> {
+        let mut models = Vec::new();
+        if self.gpt2.is_some() { models.push("gpt2"); }
+        if self.openchat.is_some() { models.push("openchat_3.5"); }
+        models
+    }
+
+    // ── /v1/models ─────────────────────────────────────────────────────
+
+    /// `GET /v1/models` — list all available models.
+    pub fn list_models(&self) -> ModelList {
+        let mut data = Vec::new();
+        if self.gpt2.is_some() {
+            data.push(gpt2::api::Gpt2Api::model_info());
+        }
+        if self.openchat.is_some() {
+            data.push(openchat::api::OpenChatApi::model_info());
+        }
+        // SD always advertised (scaffold)
+        data.push(Model::new("stable-diffusion-v1-5", "stabilityai", 0));
+        // Embedding models
+        data.push(Model::new("text-embedding-jina-v4", "jinaai", 0));
+        data.push(Model::new("text-embedding-bert-base", "google", 0));
+        ModelList::new(data)
+    }
+
+    /// `GET /v1/models/{id}` — get a specific model.
+    pub fn get_model(&self, id: &str) -> Result<Model, ApiError> {
+        match id {
+            "gpt2" if self.gpt2.is_some() => Ok(gpt2::api::Gpt2Api::model_info()),
+            "openchat_3.5" if self.openchat.is_some() => Ok(openchat::api::OpenChatApi::model_info()),
+            "stable-diffusion-v1-5" => Ok(Model::new("stable-diffusion-v1-5", "stabilityai", 0)),
+            _ => Err(ApiError::model_not_found(id)),
+        }
+    }
+
+    // ── /v1/completions ────────────────────────────────────────────────
+
+    /// `POST /v1/completions` — text completion.
+    ///
+    /// Routes to GPT-2. Returns error if GPT-2 is not loaded.
+    pub fn complete(&mut self, req: &CompletionRequest) -> Result<CompletionResponse, ApiError> {
+        let engine = self.gpt2.as_mut()
+            .ok_or_else(|| ApiError::model_not_found(&req.model))?;
+        Ok(engine.complete(req))
+    }
+
+    // ── /v1/chat/completions ───────────────────────────────────────────
+
+    /// `POST /v1/chat/completions` — chat completion.
+    ///
+    /// Routes by model name:
+    /// - `"openchat_3.5"` / `"openchat"` → OpenChat engine
+    /// - `"gpt2"` → GPT-2 via chat adapter (messages → single prompt)
+    pub fn chat_complete(&mut self, req: &ChatCompletionRequest) -> Result<ChatCompletionResponse, ApiError> {
+        match req.model.as_str() {
+            "openchat_3.5" | "openchat" => {
+                let engine = self.openchat.as_mut()
+                    .ok_or_else(|| ApiError::model_not_found(&req.model))?;
+                Ok(engine.chat_complete(req))
+            }
+            "gpt2" => {
+                // Adapter: convert chat messages to a single text prompt for GPT-2
+                let engine = self.gpt2.as_mut()
+                    .ok_or_else(|| ApiError::model_not_found("gpt2"))?;
+                let completion_req = chat_to_completion(req);
+                let completion_resp = engine.complete(&completion_req);
+                Ok(completion_to_chat(completion_resp))
+            }
+            other => Err(ApiError::model_not_found(other)),
+        }
+    }
+
+    // ── /v1/embeddings ─────────────────────────────────────────────────
+
+    /// `POST /v1/embeddings` — generate embeddings.
+    ///
+    /// Routes to GPT-2 wte (or any model that supports embeddings).
+    pub fn embed(&self, req: &EmbeddingRequest) -> Result<EmbeddingResponse, ApiError> {
+        match req.model.as_str() {
+            "gpt2" | "text-embedding-gpt2" => {
+                let engine = self.gpt2.as_ref()
+                    .ok_or_else(|| ApiError::model_not_found(&req.model))?;
+                Ok(engine.embed(req))
+            }
+            other => Err(ApiError::model_not_found(other)),
+        }
+    }
+
+    // ── /v1/images/generations ─────────────────────────────────────────
+
+    // Note: SD API is stateless per-request (no engine to hold).
+    // The router would hold an SD engine when weights are loaded.
+    // For now, return model_not_found until SD weights are registered.
+}
+
+// ============================================================================
+// Adapters: convert between completion and chat formats
+// ============================================================================
+
+/// Convert a chat request to a completion request (for GPT-2 chat adapter).
+fn chat_to_completion(req: &ChatCompletionRequest) -> CompletionRequest {
+    // Concatenate all messages into a single prompt
+    let mut prompt = String::new();
+    for msg in &req.messages {
+        if let Some(c) = &msg.content {
+            match msg.role {
+                ChatRole::System => {
+                    prompt.push_str("System: ");
+                    prompt.push_str(c);
+                    prompt.push('\n');
+                }
+                ChatRole::User => {
+                    prompt.push_str("User: ");
+                    prompt.push_str(c);
+                    prompt.push('\n');
+                }
+                ChatRole::Assistant => {
+                    prompt.push_str("Assistant: ");
+                    prompt.push_str(c);
+                    prompt.push('\n');
+                }
+                ChatRole::Tool => {
+                    prompt.push_str("Tool: ");
+                    prompt.push_str(c);
+                    prompt.push('\n');
+                }
+            }
+        }
+    }
+    prompt.push_str("Assistant:");
+
+    CompletionRequest {
+        model: "gpt2".into(),
+        prompt: Some(prompt),
+        prompt_tokens: req.prompt_tokens.clone(),
+        max_tokens: req.max_tokens,
+        temperature: req.temperature,
+        top_p: req.top_p,
+        n: req.n,
+        stream: req.stream,
+        stop: req.stop.clone(),
+        presence_penalty: req.presence_penalty,
+        frequency_penalty: req.frequency_penalty,
+        seed: req.seed,
+        ..CompletionRequest::default()
+    }
+}
+
+/// Convert a completion response to a chat response (for GPT-2 chat adapter).
+fn completion_to_chat(resp: CompletionResponse) -> ChatCompletionResponse {
+    let choices: Vec<ChatChoice> = resp.choices.into_iter().map(|c| {
+        ChatChoice {
+            index: c.index,
+            message: ChatMessage::assistant(c.text),
+            finish_reason: c.finish_reason,
+            logprobs: None,
+        }
+    }).collect();
+
+    ChatCompletionResponse::new(
+        resp.id.replace("cmpl-", "chatcmpl-"),
+        resp.model,
+        choices,
+        resp.usage,
+        resp.created,
+    )
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_empty_router() {
+        let router = ModelRouter::new();
+        assert!(router.loaded_models().is_empty());
+    }
+
+    #[test]
+    fn test_list_models_always_has_sd() {
+        let router = ModelRouter::new();
+        let list = router.list_models();
+        assert_eq!(list.object, "list");
+        assert!(list.data.iter().any(|m| m.id == "stable-diffusion-v1-5"));
+    }
+
+    #[test]
+    fn test_get_model_not_found() {
+        let router = ModelRouter::new();
+        let err = router.get_model("nonexistent");
+        assert!(err.is_err());
+    }
+
+    #[test]
+    fn test_complete_no_model() {
+        let mut router = ModelRouter::new();
+        let req = CompletionRequest { model: "gpt2".into(), ..Default::default() };
+        let err = router.complete(&req);
+        assert!(err.is_err());
+    }
+
+    #[test]
+    fn test_chat_complete_no_model() {
+        let mut router = ModelRouter::new();
+        let req = ChatCompletionRequest { model: "openchat_3.5".into(), ..Default::default() };
+        let err = router.chat_complete(&req);
+        assert!(err.is_err());
+    }
+
+    #[test]
+    fn test_embed_no_model() {
+        let router = ModelRouter::new();
+        let req = EmbeddingRequest { model: "gpt2".into(), ..Default::default() };
+        let err = router.embed(&req);
+        assert!(err.is_err());
+    }
+
+    #[test]
+    fn test_chat_to_completion_adapter() {
+        let req = ChatCompletionRequest {
+            model: "gpt2".into(),
+            messages: vec![
+                ChatMessage::system("Be helpful"),
+                ChatMessage::user("Hello"),
+            ],
+            max_tokens: Some(100),
+            temperature: Some(0.5),
+            ..Default::default()
+        };
+        let comp = chat_to_completion(&req);
+        assert!(comp.prompt.as_ref().unwrap().contains("System: Be helpful"));
+        assert!(comp.prompt.as_ref().unwrap().contains("User: Hello"));
+        assert!(comp.prompt.as_ref().unwrap().ends_with("Assistant:"));
+        assert_eq!(comp.max_tokens, Some(100));
+        assert_eq!(comp.temperature, Some(0.5));
+    }
+
+    #[test]
+    fn test_completion_to_chat_adapter() {
+        let resp = CompletionResponse::new(
+            "cmpl-42".into(),
+            "gpt2".into(),
+            vec![CompletionChoice {
+                index: 0,
+                text: "Hello world".into(),
+                logprobs: None,
+                finish_reason: Some(FinishReason::Stop),
+            }],
+            Usage { prompt_tokens: 5, completion_tokens: 2, total_tokens: 7 },
+            0,
+        );
+        let chat = completion_to_chat(resp);
+        assert_eq!(chat.object, "chat.completion");
+        assert_eq!(chat.id, "chatcmpl-42");
+        assert_eq!(chat.choices[0].message.role, ChatRole::Assistant);
+        assert_eq!(chat.choices[0].message.content.as_deref(), Some("Hello world"));
+        assert_eq!(chat.choices[0].finish_reason, Some(FinishReason::Stop));
+        assert_eq!(chat.usage.total_tokens, 7);
+    }
+
+    #[test]
+    fn test_sd_always_in_model_list() {
+        let router = ModelRouter::new();
+        let list = router.list_models();
+        let sd = list.data.iter().find(|m| m.id == "stable-diffusion-v1-5");
+        assert!(sd.is_some());
+        assert_eq!(sd.unwrap().object, "model");
+    }
+}
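The router's doc comment suggests wrapping it in `Arc<Mutex<ModelRouter>>` for concurrent use. A minimal sketch of that pattern, assuming only what the diff shows (the helper function names here are illustrative, not part of the PR):

```rust
use std::sync::{Arc, Mutex};

use crate::hpc::models::api_types::CompletionRequest;
use crate::hpc::models::router::ModelRouter;

// One router for the whole process; each handler clones the Arc.
fn shared_router(router: ModelRouter) -> Arc<Mutex<ModelRouter>> {
    Arc::new(Mutex::new(router))
}

// A handler locks for the duration of one inference call. complete()
// takes &mut self (it mutates engine state and per-engine counters),
// so a shared RwLock read guard would not be enough here.
fn handle_completion(shared: &Arc<Mutex<ModelRouter>>, req: &CompletionRequest) -> String {
    match shared.lock().unwrap().complete(req) {
        Ok(resp) => resp.id,
        Err(e) => format!("error: {}", e.message),
    }
}
```

Holding the lock across a full generation serializes all inference; that is the simple-and-correct starting point, with per-engine locks as the obvious refinement if throughput matters.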
diff --git a/src/hpc/openchat/api.rs b/src/hpc/openchat/api.rs
index 1db5d59f..ac266c16 100644
--- a/src/hpc/openchat/api.rs
+++ b/src/hpc/openchat/api.rs
@@ -1,72 +1,13 @@
-//! OpenAI-compatible chat completions API for OpenChat 3.5.
+//! OpenChat 3.5 API — wraps the inference engine with OpenAI-compatible types.
 //!
-//! Implements `/v1/chat/completions` with the OpenChat template:
-//! ```text
-//! GPT4 Correct User: {message}<|end_of_turn|>
-//! GPT4 Correct Assistant:
-//! ```
+//! Endpoint: `/v1/chat/completions`
+//!
+//! Uses the OpenChat template: `GPT4 Correct User: {msg}<|end_of_turn|>`
 
-use crate::hpc::models::api_types::{Usage, FinishReason};
+use crate::hpc::models::api_types::*;
 use super::inference::{GeneratedToken, OpenChatEngine};
 use super::weights::*;
 
-/// A chat message (role + content).
-#[derive(Clone, Debug)]
-pub struct ChatMessage {
-    pub role: ChatRole,
-    pub content: String,
-}
-
-/// Chat role.
-#[derive(Clone, Debug, PartialEq, Eq)]
-pub enum ChatRole {
-    System,
-    User,
-    Assistant,
-}
-
-/// Request body for /v1/chat/completions.
-#[derive(Clone, Debug)]
-pub struct ChatCompletionRequest {
-    pub model: String,
-    pub messages: Vec<ChatMessage>,
-    /// Pre-tokenized prompt (built from messages via chat template).
-    pub prompt_tokens: Vec<u32>,
-    pub max_tokens: usize,
-    pub temperature: f32,
-    pub stream: bool,
-}
-
-impl Default for ChatCompletionRequest {
-    fn default() -> Self {
-        Self {
-            model: "openchat_3.5".into(),
-            messages: Vec::new(),
-            prompt_tokens: Vec::new(),
-            max_tokens: 512,
-            temperature: 0.7,
-            stream: false,
-        }
-    }
-}
-
-/// A chat completion choice.
-#[derive(Clone, Debug)]
-pub struct ChatChoice {
-    pub index: usize,
-    pub message: ChatMessage,
-    pub finish_reason: FinishReason,
-}
-
-/// Response body for /v1/chat/completions.
-#[derive(Clone, Debug)]
-pub struct ChatCompletionResponse {
-    pub id: String,
-    pub model: String,
-    pub choices: Vec<ChatChoice>,
-    pub usage: Usage,
-}
-
 /// OpenChat API wrapper.
 pub struct OpenChatApi {
     engine: OpenChatEngine,
@@ -75,96 +16,95 @@ pub struct OpenChatApi {
 impl OpenChatApi {
     pub fn new(weights: OpenChatWeights) -> Self {
-        Self {
-            engine: OpenChatEngine::new(weights),
-            request_counter: 0,
-        }
+        Self { engine: OpenChatEngine::new(weights), request_counter: 0 }
     }
 
-    /// /v1/chat/completions handler.
+    /// `/v1/chat/completions`
     pub fn chat_complete(&mut self, req: &ChatCompletionRequest) -> ChatCompletionResponse {
         self.request_counter += 1;
+        let tokens = req.prompt_tokens.as_deref().unwrap_or(&[]);
+        let max = req.max_tokens.unwrap_or(512);
+        let temp = req.temperature.unwrap_or(0.7);
 
-        let generated = self.engine.generate(
-            &req.prompt_tokens,
-            req.max_tokens,
-            req.temperature,
-        );
+        let generated = self.engine.generate(tokens, max, temp);
 
-        let finish_reason = if generated.len() < req.max_tokens {
+        let finish_reason = if generated.len() < max {
             FinishReason::Stop
         } else {
             FinishReason::Length
        };
 
-        let completion_tokens = generated.len();
-        let prompt_tokens = req.prompt_tokens.len();
+        let content: String = generated.iter().map(|t| format!("[{}]", t.token_id)).collect();
 
-        ChatCompletionResponse {
-            id: format!("chatcmpl-{}", self.request_counter),
-            model: "openchat_3.5".into(),
-            choices: vec![ChatChoice {
+        ChatCompletionResponse::new(
+            format!("chatcmpl-{}", self.request_counter),
+            "openchat_3.5".into(),
+            vec![ChatChoice {
                 index: 0,
-                message: ChatMessage {
-                    role: ChatRole::Assistant,
-                    content: format!("[{} tokens generated]", completion_tokens),
-                },
-                finish_reason,
+                message: ChatMessage::assistant(content),
+                finish_reason: Some(finish_reason),
+                logprobs: None,
             }],
-            usage: Usage {
-                prompt_tokens,
-                completion_tokens,
-                total_tokens: prompt_tokens + completion_tokens,
+            Usage {
+                prompt_tokens: tokens.len(),
+                completion_tokens: generated.len(),
+                total_tokens: tokens.len() + generated.len(),
            },
-        }
+            0,
+        )
     }
 
-    /// Build prompt token sequence from chat messages using OpenChat template.
+    /// Build prompt string from chat messages using OpenChat template.
     ///
-    /// Format:
     /// ```text
-    /// GPT4 Correct User: {user_msg}<|end_of_turn|>GPT4 Correct Assistant:
+    /// GPT4 Correct User: {msg}<|end_of_turn|>GPT4 Correct Assistant:
     /// ```
-    ///
-    /// Returns a description of the template (actual tokenization requires SentencePiece).
     pub fn format_chat_template(messages: &[ChatMessage]) -> String {
         let mut prompt = String::new();
         for msg in messages {
             match msg.role {
                 ChatRole::System => {
-                    prompt.push_str(&msg.content);
-                    prompt.push('\n');
+                    if let Some(c) = &msg.content {
+                        prompt.push_str(c);
+                        prompt.push('\n');
+                    }
                 }
                 ChatRole::User => {
                     prompt.push_str(chat_template::USER_PREFIX);
-                    prompt.push_str(&msg.content);
+                    if let Some(c) = &msg.content {
+                        prompt.push_str(c);
+                    }
                     prompt.push_str(chat_template::EOT_TOKEN);
                 }
                 ChatRole::Assistant => {
                     prompt.push_str(chat_template::ASSISTANT_PREFIX);
                     prompt.push(' ');
-                    prompt.push_str(&msg.content);
+                    if let Some(c) = &msg.content {
+                        prompt.push_str(c);
+                    }
+                    prompt.push_str(chat_template::EOT_TOKEN);
+                }
+                ChatRole::Tool => {
+                    // Tool responses treated as user messages
+                    prompt.push_str(chat_template::USER_PREFIX);
+                    if let Some(c) = &msg.content {
+                        prompt.push_str(c);
+                    }
                     prompt.push_str(chat_template::EOT_TOKEN);
                 }
             }
         }
-        // Always end with assistant prefix to prompt generation
         prompt.push_str(chat_template::ASSISTANT_PREFIX);
         prompt
     }
 
-    /// Access engine for direct manipulation.
-    pub fn engine_mut(&mut self) -> &mut OpenChatEngine {
-        &mut self.engine
+    /// `/v1/models/{id}`
+    pub fn model_info() -> Model {
+        Model::new("openchat_3.5", "openchat", 0)
     }
 
-    /// Model info.
-    pub fn model_info() -> crate::hpc::models::api_types::ModelCard {
-        crate::hpc::models::api_types::ModelCard {
-            id: "openchat_3.5".into(),
-            owned_by: "openchat".into(),
-            created: 0,
-        }
+    pub fn engine_mut(&mut self) -> &mut OpenChatEngine {
+        &mut self.engine
     }
 }
 
@@ -174,9 +114,7 @@ mod tests {
     #[test]
     fn test_chat_template_format() {
-        let messages = vec![
-            ChatMessage { role: ChatRole::User, content: "Hello!".into() },
-        ];
+        let messages = vec![ChatMessage::user("Hello!")];
         let prompt = OpenChatApi::format_chat_template(&messages);
         assert!(prompt.contains("GPT4 Correct User: Hello!"));
         assert!(prompt.contains("<|end_of_turn|>"));
     }
 
     #[test]
     fn test_chat_template_multi_turn() {
         let messages = vec![
-            ChatMessage { role: ChatRole::User, content: "Hi".into() },
-            ChatMessage { role: ChatRole::Assistant, content: "Hello!".into() },
-            ChatMessage { role: ChatRole::User, content: "How are you?".into() },
+            ChatMessage::user("Hi"),
+            ChatMessage::assistant("Hello!"),
+            ChatMessage::user("How are you?"),
         ];
         let prompt = OpenChatApi::format_chat_template(&messages);
-        // Should have two user turns and one assistant turn
         assert_eq!(prompt.matches("GPT4 Correct User:").count(), 2);
         assert!(prompt.contains("Hello!"));
     }
 
     #[test]
     fn test_chat_template_with_system() {
         let messages = vec![
-            ChatMessage { role: ChatRole::System, content: "You are helpful.".into() },
-            ChatMessage { role: ChatRole::User, content: "Hi".into() },
+            ChatMessage::system("You are helpful."),
+            ChatMessage::user("Hi"),
         ];
         let prompt = OpenChatApi::format_chat_template(&messages);
         assert!(prompt.starts_with("You are helpful."));
     }
 
     #[test]
     fn test_default_request() {
         let req = ChatCompletionRequest::default();
-        assert_eq!(req.model, "openchat_3.5");
-        assert_eq!(req.max_tokens, 512);
-        assert!(!req.stream);
+        assert_eq!(req.max_tokens, Some(512));
+        assert_eq!(req.stream, Some(false));
    }
 
     #[test]
     fn test_model_info() {
-        let info = OpenChatApi::model_info();
-        assert_eq!(info.id, "openchat_3.5");
+        let m = OpenChatApi::model_info();
+        assert_eq!(m.id, "openchat_3.5");
+        assert_eq!(m.object, "model");
     }
 
     #[test]
-    fn test_chat_role_eq() {
-        assert_eq!(ChatRole::User, ChatRole::User);
-        assert_ne!(ChatRole::User, ChatRole::Assistant);
+    fn test_chat_response_object() {
+        let resp = ChatCompletionResponse::new("x".into(), "m".into(), vec![], Usage::default(), 0);
+        assert_eq!(resp.object, "chat.completion");
     }
 }
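Usage sketch for the template builder: it turns a short conversation into the flat prompt string the engine expects. Token IDs would still have to come from a SentencePiece tokenizer, which this diff does not include; the exact prefix constants live in `chat_template` and are assumed to match the test assertions.

```rust
use crate::hpc::models::api_types::ChatMessage;
use crate::hpc::openchat::api::OpenChatApi;

fn build_prompt() -> String {
    let messages = vec![
        ChatMessage::system("You are a concise assistant."),
        ChatMessage::user("Summarize DDIM in one sentence."),
    ];
    // Yields roughly:
    // "You are a concise assistant.\nGPT4 Correct User: Summarize DDIM in
    //  one sentence.<|end_of_turn|>GPT4 Correct Assistant:"
    // The trailing assistant prefix is always appended to cue generation.
    OpenChatApi::format_chat_template(&messages)
}
```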
diff --git a/src/hpc/stable_diffusion/api.rs b/src/hpc/stable_diffusion/api.rs
index 7d56eb6e..fa890c73 100644
--- a/src/hpc/stable_diffusion/api.rs
+++ b/src/hpc/stable_diffusion/api.rs
@@ -1,49 +1,13 @@
-//! OpenAI-compatible API for Stable Diffusion (/v1/images/generations).
+//! Stable Diffusion API — wraps the pipeline with OpenAI-compatible types.
 //!
-//! Uses shared `models::api_types` for the response envelope.
+//! Endpoint: `/v1/images/generations`
 
-use crate::hpc::models::api_types::{ImageData, ImageResponse};
+use crate::hpc::models::api_types::*;
 use super::clip::ClipEncoder;
 use super::scheduler::{DdimScheduler, SchedulerConfig};
 use super::unet;
 use super::vae;
 
-/// Request body for /v1/images/generations.
-#[derive(Clone, Debug)]
-pub struct ImageGenerationRequest {
-    pub model: String,
-    pub prompt: String,
-    /// Pre-tokenized prompt (CLIP tokens).
-    pub prompt_tokens: Vec<u32>,
-    /// Image dimensions.
-    pub width: usize,
-    pub height: usize,
-    /// Number of diffusion steps.
-    pub num_steps: usize,
-    /// Classifier-free guidance scale.
-    pub guidance_scale: f32,
-    /// Random seed for reproducibility.
-    pub seed: u64,
-    /// Number of images to generate.
-    pub n: usize,
-}
-
-impl Default for ImageGenerationRequest {
-    fn default() -> Self {
-        Self {
-            model: "stable-diffusion-v1-5".into(),
-            prompt: String::new(),
-            prompt_tokens: Vec::new(),
-            width: 512,
-            height: 512,
-            num_steps: 20,
-            guidance_scale: 7.5,
-            seed: 42,
-            n: 1,
-        }
-    }
-}
-
 /// Stable Diffusion API wrapper.
 pub struct StableDiffusionApi {
     clip: ClipEncoder,
@@ -52,71 +16,67 @@ impl StableDiffusionApi {
     pub fn new(clip: ClipEncoder) -> Self {
-        Self {
-            clip,
-            scheduler_config: SchedulerConfig::default(),
-        }
+        Self { clip, scheduler_config: SchedulerConfig::default() }
     }
 
-    /// Generate images from a text prompt.
-    ///
-    /// Full pipeline: CLIP encode → UNet denoise × N steps → VAE decode.
+    /// `/v1/images/generations`
     pub fn generate(&self, req: &ImageGenerationRequest) -> ImageResponse {
-        let latent_h = req.height / 8;
-        let latent_w = req.width / 8;
+        let (w, h) = req.dimensions();
+        let latent_h = h / 8;
+        let latent_w = w / 8;
+        let n = req.n.unwrap_or(1);
+        let seed = req.seed.unwrap_or(42);
+        let prompt_tokens = req.prompt_tokens.as_deref().unwrap_or(&[]);
 
-        // CLIP encode
-        let text_embeddings = self.clip.encode(&req.prompt_tokens);
+        let text_embeddings = self.clip.encode(prompt_tokens);
 
-        // Initialize scheduler
         let scheduler = DdimScheduler::new(SchedulerConfig {
-            num_inference_steps: req.num_steps,
+            num_inference_steps: 20,
             ..self.scheduler_config.clone()
         });
 
-        // Initialize latent with deterministic noise from seed
-        let latent_size = 4 * latent_h * latent_w;
-        let mut latent = generate_noise(latent_size, req.seed);
-
-        // Denoising loop
-        for &t in &scheduler.timesteps {
-            let noise_pred = unet::predict_noise(&latent, &text_embeddings, t as f32);
-            latent = scheduler.step(&noise_pred, t, &latent);
-        }
+        let mut images = Vec::with_capacity(n);
+        for img_idx in 0..n {
+            let latent_size = 4 * latent_h * latent_w;
+            let mut latent = generate_noise(latent_size, seed.wrapping_add(img_idx as u64));
 
-        // VAE decode
-        let pixels = vae::decode(&latent, latent_h, latent_w);
-        let rgb = vae::to_rgb_u8(&pixels, 3, req.height, req.width);
+            for &t in &scheduler.timesteps {
+                let noise_pred = unet::predict_noise(&latent, &text_embeddings, t as f32);
+                latent = scheduler.step(&noise_pred, t, &latent);
+            }
 
-        // Encode as base64 PNG (scaffold — actual PNG encoding would be here)
-        let b64 = base64_placeholder(&rgb, req.width, req.height);
+            let pixels = vae::decode(&latent, latent_h, latent_w);
+            let rgb = vae::to_rgb_u8(&pixels, 3, h, w);
+            let b64 = base64_placeholder(&rgb);
 
-        ImageResponse {
-            created: 0,
-            data: vec![ImageData {
+            images.push(ImageData {
                 b64_json: Some(b64),
                 url: None,
                 revised_prompt: Some(req.prompt.clone()),
-            }],
+            });
         }
+
+        ImageResponse { created: 0, data: images }
+    }
+
+    /// `/v1/models/{id}`
+    pub fn model_info() -> Model {
+        Model::new("stable-diffusion-v1-5", "stabilityai", 0)
     }
 }
 
-/// Deterministic noise from seed (xoshiro256++).
 fn generate_noise(size: usize, seed: u64) -> Vec<f32> {
     let mut state = seed;
     let mut noise = Vec::with_capacity(size);
     for _ in 0..size {
         state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
-        // Convert to f32 in [-1, 1] range
         let bits = ((state >> 32) as u32) as f32 / u32::MAX as f32;
         noise.push(bits * 2.0 - 1.0);
     }
     noise
 }
 
-/// Placeholder base64 (actual PNG encoding would need a png crate).
-fn base64_placeholder(rgb: &[u8], _w: usize, _h: usize) -> String {
+fn base64_placeholder(rgb: &[u8]) -> String {
     format!("raw_rgb_{}_bytes", rgb.len())
 }
 
@@ -138,34 +98,48 @@ mod tests {
     #[test]
     fn test_default_request() {
         let req = ImageGenerationRequest::default();
-        assert_eq!(req.width, 512);
-        assert_eq!(req.height, 512);
-        assert_eq!(req.num_steps, 20);
+        assert_eq!(req.dimensions(), (512, 512));
+        assert_eq!(req.n, Some(1));
     }
 
     #[test]
     fn test_generate_returns_image() {
         let api = StableDiffusionApi::new(dummy_clip());
         let req = ImageGenerationRequest {
-            prompt_tokens: vec![0, 1, 2],
+            prompt: "a cat".into(),
+            prompt_tokens: Some(vec![0, 1, 2]),
             ..Default::default()
         };
         let resp = api.generate(&req);
         assert_eq!(resp.data.len(), 1);
         assert!(resp.data[0].b64_json.is_some());
+        assert_eq!(resp.data[0].revised_prompt.as_deref(), Some("a cat"));
+    }
+
+    #[test]
+    fn test_generate_multiple() {
+        let api = StableDiffusionApi::new(dummy_clip());
+        let req = ImageGenerationRequest {
+            prompt: "test".into(),
+            prompt_tokens: Some(vec![0]),
+            n: Some(3),
+            ..Default::default()
+        };
+        let resp = api.generate(&req);
+        assert_eq!(resp.data.len(), 3);
     }
 
     #[test]
-    fn test_deterministic_noise() {
+    fn test_deterministic() {
         let n1 = generate_noise(100, 42);
         let n2 = generate_noise(100, 42);
-        assert_eq!(n1, n2, "same seed should give same noise");
+        assert_eq!(n1, n2);
     }
 
     #[test]
-    fn test_different_seeds() {
-        let n1 = generate_noise(100, 42);
-        let n2 = generate_noise(100, 123);
-        assert_ne!(n1, n2);
+    fn test_model_info() {
+        let m = StableDiffusionApi::model_info();
+        assert_eq!(m.id, "stable-diffusion-v1-5");
+        assert_eq!(m.object, "model");
    }
 }
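`base64_placeholder` still returns a fake string rather than a real `b64_json` payload. A sketch of what the real encoder could look like, under the assumption of adding the `png` and `base64` crates as dependencies (neither is part of this diff, and exact APIs vary by crate version):

```rust
// Hypothetical replacement for base64_placeholder; assumes png ~0.17
// and base64 ~0.21 as new dependencies.
use base64::Engine;

fn encode_png_b64(rgb: &[u8], w: usize, h: usize) -> Result<String, png::EncodingError> {
    let mut buf = Vec::new();
    {
        // Write an 8-bit RGB PNG into the in-memory buffer.
        let mut enc = png::Encoder::new(&mut buf, w as u32, h as u32);
        enc.set_color(png::ColorType::Rgb);
        enc.set_depth(png::BitDepth::Eight);
        let mut writer = enc.write_header()?;
        writer.write_image_data(rgb)?; // expects exactly w * h * 3 bytes
    }
    // Standard (padded) base64, matching OpenAI's b64_json field.
    Ok(base64::engine::general_purpose::STANDARD.encode(&buf))
}
```

Because `generate()` already derives per-image seeds with `seed.wrapping_add(img_idx)`, swapping this in keeps the whole images endpoint deterministic for a fixed request.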