261 changes: 66 additions & 195 deletions src/hpc/gpt2/api.rs
@@ -1,222 +1,97 @@
//! OpenAI-compatible API types for GPT-2 inference.
//! GPT-2 API — wraps the inference engine with OpenAI-compatible types.
//!
//! Provides request/response structs matching the OpenAI API surface:
//! Endpoints:
//! - `/v1/completions` — text completion
//! - `/v1/embeddings` — token embeddings via wte
//! - `/v1/models` — model listing
//!
//! These types are transport-agnostic — they serialize/deserialize
//! but don't depend on any HTTP framework.
//! - `/v1/models` — model info

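Because these types are transport-agnostic, wiring the handlers to a server is a thin dispatch layer over `Gpt2Api`. A hypothetical sketch; the `ApiCall` envelope and the stringly-typed return are illustration-only, and `Debug` derives on the response types are assumed:

```rust
/// Hypothetical request envelope; JSON (de)serialization is elided.
enum ApiCall {
    Completions(CompletionRequest),
    Embeddings(EmbeddingRequest),
    Models,
}

/// Route a decoded call to the matching handler.
fn dispatch(api: &mut Gpt2Api, call: ApiCall) -> String {
    match call {
        ApiCall::Completions(req) => format!("{:?}", api.complete(&req)),
        ApiCall::Embeddings(req) => format!("{:?}", api.embed(&req)),
        ApiCall::Models => format!("{:?}", Gpt2Api::model_info()),
    }
}
```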
use crate::hpc::models::api_types::*;
use super::inference::{GeneratedToken, Gpt2Engine};
use super::weights::*;

// ============================================================================
// /v1/completions
// ============================================================================

/// Request body for /v1/completions.
#[derive(Clone, Debug)]
pub struct CompletionRequest {
/// Model name (ignored — we only have gpt2).
pub model: String,
/// Input text prompt (will be tokenized externally).
pub prompt_tokens: Vec<u32>,
/// Maximum tokens to generate.
pub max_tokens: usize,
/// Sampling temperature (1.0 is effectively greedy).
pub temperature: f32,
/// Stop token ID (default: 50256 = <|endoftext|>).
pub stop_token: Option<u32>,
}

impl Default for CompletionRequest {
fn default() -> Self {
Self {
model: "gpt2".into(),
prompt_tokens: Vec::new(),
max_tokens: 128,
temperature: 1.0,
stop_token: Some(50256),
}
}
}

/// Single completion choice.
#[derive(Clone, Debug)]
pub struct CompletionChoice {
pub index: usize,
pub tokens: Vec<GeneratedToken>,
pub finish_reason: FinishReason,
}

/// Why generation stopped.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FinishReason {
Stop,
Length,
}

/// Response body for /v1/completions.
#[derive(Clone, Debug)]
pub struct CompletionResponse {
pub id: String,
pub model: String,
pub choices: Vec<CompletionChoice>,
pub usage: Usage,
}

/// Token usage statistics.
#[derive(Clone, Debug, Default)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}

// ============================================================================
// /v1/embeddings
// ============================================================================

/// Request body for /v1/embeddings.
#[derive(Clone, Debug)]
pub struct EmbeddingRequest {
pub model: String,
/// Token IDs to embed (one embedding per token).
pub input_tokens: Vec<u32>,
}

/// Single embedding result.
#[derive(Clone, Debug)]
pub struct EmbeddingData {
pub index: usize,
pub embedding: Vec<f32>,
}

/// Response body for /v1/embeddings.
#[derive(Clone, Debug)]
pub struct EmbeddingResponse {
pub model: String,
pub data: Vec<EmbeddingData>,
pub usage: Usage,
}

// ============================================================================
// /v1/models
// ============================================================================

/// Model info for /v1/models.
#[derive(Clone, Debug)]
pub struct ModelInfo {
pub id: String,
pub owned_by: String,
pub vocab_size: usize,
pub embed_dim: usize,
pub num_layers: usize,
pub num_heads: usize,
pub max_seq_len: usize,
}

impl ModelInfo {
/// GPT-2 small (124M) model info.
pub fn gpt2_small() -> Self {
Self {
id: "gpt2".into(),
owned_by: "adaworldapi".into(),
vocab_size: VOCAB_SIZE,
embed_dim: EMBED_DIM,
num_layers: NUM_LAYERS,
num_heads: NUM_HEADS,
max_seq_len: MAX_SEQ_LEN,
}
}
}
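
As a sanity check, these dimensions really do produce the advertised 124M parameters. A back-of-the-envelope count, assuming the standard GPT-2 layout (tied input/output embeddings, biases throughout, two layernorms per block plus a final one):

```rust
fn main() {
    // GPT-2 small dimensions, matching ModelInfo::gpt2_small() above.
    let (v, d, l, s) = (50257usize, 768usize, 12usize, 1024usize);
    let embed = v * d + s * d;                         // wte + wpe (output head tied to wte)
    let attn = d * 3 * d + 3 * d + d * d + d;          // qkv projection + output projection, biases included
    let mlp = d * 4 * d + 4 * d + 4 * d * d + d;       // up + down projection, biases included
    let ln = 2 * 2 * d;                                // two layernorms per block (gamma and beta each)
    let total = embed + l * (attn + mlp + ln) + 2 * d; // all blocks, plus the final ln_f
    assert_eq!(total, 124_439_808);                    // the familiar "124M"
}
```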

// ============================================================================
// Engine wrapper — stateless API over stateful engine
// ============================================================================

/// Stateless API wrapper around Gpt2Engine.
/// Handles request→response conversion.
pub struct Gpt2Api {
engine: Gpt2Engine,
request_counter: u64,
}

impl Gpt2Api {
/// Create from pre-loaded weights.
pub fn new(weights: Gpt2Weights) -> Self {
Self {
engine: Gpt2Engine::new(weights),
request_counter: 0,
}
Self { engine: Gpt2Engine::new(weights), request_counter: 0 }
}

/// /v1/completions handler.
/// `/v1/completions`
pub fn complete(&mut self, req: &CompletionRequest) -> CompletionResponse {
self.request_counter += 1;
let tokens = req.prompt_tokens.as_deref().unwrap_or(&[]);
let max = req.max_tokens.unwrap_or(128);
let temp = req.temperature.unwrap_or(1.0);

let generated = self.engine.generate(
&req.prompt_tokens,
req.max_tokens,
req.temperature,
);
let generated = self.engine.generate(tokens, max, temp);

let finish_reason = if generated.len() < req.max_tokens {
let finish_reason = if generated.len() < max {
FinishReason::Stop
} else {
FinishReason::Length
};

let completion_tokens = generated.len();
let prompt_tokens = req.prompt_tokens.len();

CompletionResponse {
id: format!("cmpl-{}", self.request_counter),
model: "gpt2".into(),
choices: vec![CompletionChoice {
let text = generated.iter().map(|t| format!("[{}]", t.token_id)).collect::<String>();
let logprobs: Vec<LogprobInfo> = generated.iter().map(|t| LogprobInfo {
token: format!("{}", t.token_id),
token_id: t.token_id,
logprob: t.logprob,
bytes: None,
top_logprobs: Vec::new(),
}).collect();

let use_logprobs = req.logprobs.is_some();

CompletionResponse::new(
format!("cmpl-{}", self.request_counter),
"gpt2".into(),
vec![CompletionChoice {
index: 0,
tokens: generated,
finish_reason,
text,
logprobs: if use_logprobs { Some(logprobs) } else { None },
finish_reason: Some(finish_reason),
}],
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
Usage {
prompt_tokens: tokens.len(),
completion_tokens: generated.len(),
total_tokens: tokens.len() + generated.len(),
},
}
0,
)
}
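
For orientation, a sketch of driving the handler directly. The shared `CompletionRequest` now lives in `api_types`, so its optional fields below are inferred from their usage above; `Gpt2Weights::load` is a stand-in for however weights actually get loaded, and public `usage` fields on the response are assumed:

```rust
// Hypothetical driver; the token IDs are illustrative GPT-2 BPE values.
let weights = Gpt2Weights::load("gpt2.bin").expect("load weights"); // placeholder loader
let mut api = Gpt2Api::new(weights);

let resp = api.complete(&CompletionRequest {
    prompt_tokens: Some(vec![15496, 11, 995]),
    max_tokens: Some(32),
    ..Default::default()
});

assert!(resp.usage.completion_tokens <= 32);
assert_eq!(resp.usage.total_tokens,
           resp.usage.prompt_tokens + resp.usage.completion_tokens);
```

Note the finish-reason convention: generating fewer than `max_tokens` tokens is read as a stop-token hit (`FinishReason::Stop`); otherwise the length cap was reached (`FinishReason::Length`).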

/// /v1/embeddings handler — returns wte embeddings for token IDs.
/// `/v1/embeddings`
pub fn embed(&self, req: &EmbeddingRequest) -> EmbeddingResponse {
let mut data = Vec::with_capacity(req.input_tokens.len());

for (idx, &token_id) in req.input_tokens.iter().enumerate() {
let offset = token_id as usize * EMBED_DIM;
let embedding = self.engine.weights().wte[offset..offset + EMBED_DIM].to_vec();
data.push(EmbeddingData {
index: idx,
embedding,
});
}

EmbeddingResponse {
model: "gpt2".into(),
let token_ids: Vec<u32> = match &req.input {
EmbeddingInput::TokenIds(ids) => ids.clone(),
_ => req.input_tokens.clone().unwrap_or_default(),
};

let data: Vec<EmbeddingData> = token_ids.iter().enumerate().map(|(idx, &tid)| {
let offset = tid as usize * EMBED_DIM;
let mut emb = self.engine.weights().wte[offset..offset + EMBED_DIM].to_vec();
if let Some(dim) = req.dimensions {
emb.truncate(dim);
}
EmbeddingData::new(idx, emb)
}).collect();

EmbeddingResponse::new(
"gpt2".into(),
data,
usage: Usage {
prompt_tokens: req.input_tokens.len(),
completion_tokens: 0,
total_tokens: req.input_tokens.len(),
},
}
Usage { prompt_tokens: token_ids.len(), completion_tokens: 0, total_tokens: token_ids.len() },
)
}
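
The lookup itself is plain row indexing into the flat, row-major `wte` matrix, one row of `EMBED_DIM` floats per vocabulary entry. A self-contained sketch of the offset arithmetic, using a toy vocabulary:

```rust
/// Row `token_id` of a row-major (vocab_size x dim) matrix stored as one flat slice.
fn embedding_row(wte: &[f32], token_id: u32, dim: usize) -> &[f32] {
    let offset = token_id as usize * dim;
    &wte[offset..offset + dim]
}

fn main() {
    // Toy matrix: 3 tokens, dim 2.
    let wte = [0.0f32, 0.1, 1.0, 1.1, 2.0, 2.1];
    assert_eq!(embedding_row(&wte, 2, 2), &[2.0, 2.1][..]);
}
```

An out-of-range `token_id` panics on the slice bounds; the handler above inherits the same behavior.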

/// /v1/models handler.
pub fn model_info(&self) -> ModelInfo {
ModelInfo::gpt2_small()
/// `/v1/models/{id}`
pub fn model_info() -> Model {
Model::new("gpt2", "adaworldapi", 0)
}

/// Access the underlying engine (for advanced usage).
pub fn engine_mut(&mut self) -> &mut Gpt2Engine {
&mut self.engine
}
@@ -228,25 +103,21 @@ mod tests {

#[test]
fn test_model_info() {
let info = ModelInfo::gpt2_small();
assert_eq!(info.vocab_size, 50257);
assert_eq!(info.embed_dim, 768);
assert_eq!(info.num_layers, 12);
assert_eq!(info.num_heads, 12);
assert_eq!(info.max_seq_len, 1024);
let m = Gpt2Api::model_info();
assert_eq!(m.id, "gpt2");
assert_eq!(m.object, "model");
}

#[test]
fn test_completion_request_default() {
fn test_completion_defaults() {
let req = CompletionRequest::default();
assert_eq!(req.max_tokens, 128);
assert_eq!(req.temperature, 1.0);
assert_eq!(req.stop_token, Some(50256));
assert_eq!(req.model, "gpt2");
assert_eq!(req.max_tokens, Some(128));
}

#[test]
fn test_finish_reason_variants() {
assert_eq!(FinishReason::Stop, FinishReason::Stop);
assert_ne!(FinishReason::Stop, FinishReason::Length);
fn test_completion_response_object() {
let resp = CompletionResponse::new("x".into(), "gpt2".into(), vec![], Usage::default(), 0);
assert_eq!(resp.object, "text_completion");
}
}