diff --git a/src/hpc/gguf.rs b/src/hpc/gguf.rs index d969e305..874d0a63 100644 --- a/src/hpc/gguf.rs +++ b/src/hpc/gguf.rs @@ -225,6 +225,12 @@ pub fn read_tensor_f32( GgmlType::Q8_0 => { dequantize_q8_0(reader, n_elements) } + GgmlType::Q4_0 => { + dequantize_q4_0(reader, n_elements) + } + GgmlType::Q4_K => { + dequantize_q4_k(reader, n_elements) + } other => Err(format!("Unsupported dtype for dequantization: {:?}", other)), } } @@ -317,6 +323,90 @@ fn dequantize_q8_0(r: &mut R, n_elements: usize) -> Result, St Ok(result) } +/// Dequantize Q4_0: each block = 2 bytes scale (f16) + 16 bytes (32 nibbles). +fn dequantize_q4_0(r: &mut R, n_elements: usize) -> Result, String> { + let block_size = 32; + let n_blocks = (n_elements + block_size - 1) / block_size; + let mut result = Vec::with_capacity(n_elements); + + for _ in 0..n_blocks { + let mut scale_buf = [0u8; 2]; + r.read_exact(&mut scale_buf).map_err(|e| e.to_string())?; + let scale = f16_to_f32(u16::from_le_bytes(scale_buf)); + + let mut nibbles = [0u8; 16]; + r.read_exact(&mut nibbles).map_err(|e| e.to_string())?; + + for &byte in &nibbles { + let lo = (byte & 0x0F) as i8 - 8; + let hi = ((byte >> 4) & 0x0F) as i8 - 8; + result.push(lo as f32 * scale); + result.push(hi as f32 * scale); + } + } + + result.truncate(n_elements); + Ok(result) +} + +/// Dequantize Q4_K: super-blocks of 256 elements. +/// +/// Q4_K block layout (144 bytes for 256 elements): +/// - 2 bytes: d (f16 scale) +/// - 2 bytes: dmin (f16 min) +/// - 12 bytes: scales (6-bit per sub-block, packed) +/// - 128 bytes: 256 4-bit quantized values (nibbles) +fn dequantize_q4_k(r: &mut R, n_elements: usize) -> Result, String> { + let block_size = 256; + let n_blocks = (n_elements + block_size - 1) / block_size; + let mut result = Vec::with_capacity(n_elements); + + for _ in 0..n_blocks { + // Read d and dmin (f16) + let mut d_buf = [0u8; 2]; + let mut dmin_buf = [0u8; 2]; + r.read_exact(&mut d_buf).map_err(|e| e.to_string())?; + r.read_exact(&mut dmin_buf).map_err(|e| e.to_string())?; + let d = f16_to_f32(u16::from_le_bytes(d_buf)); + let dmin = f16_to_f32(u16::from_le_bytes(dmin_buf)); + + // Read scales (12 bytes = 8 sub-block scales + 8 sub-block mins, 6-bit packed) + let mut scales_raw = [0u8; 12]; + r.read_exact(&mut scales_raw).map_err(|e| e.to_string())?; + + // Decode 8 scale/min pairs from 12 bytes (6 bits each) + let mut sc = [0u8; 8]; + let mut mn = [0u8; 8]; + for i in 0..4 { + sc[i] = scales_raw[i] & 0x3F; + mn[i] = scales_raw[i + 4] & 0x3F; + sc[i + 4] = ((scales_raw[i + 8] & 0x0F) << 2) | (scales_raw[i] >> 6); + mn[i + 4] = ((scales_raw[i + 8] >> 4) << 2) | (scales_raw[i + 4] >> 6); + } + + // Read 128 bytes of nibbles (256 4-bit values) + let mut nibbles = [0u8; 128]; + r.read_exact(&mut nibbles).map_err(|e| e.to_string())?; + + // Dequantize: each sub-block of 32 elements + for j in 0..8 { + let sub_d = d * sc[j] as f32; + let sub_m = dmin * mn[j] as f32; + let nib_offset = j * 16; + for k in 0..16 { + let byte = nibbles[nib_offset + k]; + let lo = (byte & 0x0F) as f32; + let hi = ((byte >> 4) & 0x0F) as f32; + result.push(lo * sub_d - sub_m); + result.push(hi * sub_d - sub_m); + } + } + } + + result.truncate(n_elements); + Ok(result) +} + /// Convert f16 bit pattern to f32. fn f16_to_f32(bits: u16) -> f32 { let sign = ((bits >> 15) & 1) as u32; diff --git a/src/hpc/gpt2/api.rs b/src/hpc/gpt2/api.rs new file mode 100644 index 00000000..677d7fb3 --- /dev/null +++ b/src/hpc/gpt2/api.rs @@ -0,0 +1,252 @@ +//! OpenAI-compatible API types for GPT-2 inference. +//! 
+//! Provides request/response structs matching the OpenAI API surface: +//! - `/v1/completions` — text completion +//! - `/v1/embeddings` — token embeddings via wte +//! - `/v1/models` — model listing +//! +//! These types are transport-agnostic — they serialize/deserialize +//! but don't depend on any HTTP framework. + +use super::inference::{GeneratedToken, Gpt2Engine}; +use super::weights::*; + +// ============================================================================ +// /v1/completions +// ============================================================================ + +/// Request body for /v1/completions. +#[derive(Clone, Debug)] +pub struct CompletionRequest { + /// Model name (ignored — we only have gpt2). + pub model: String, + /// Input text prompt (will be tokenized externally). + pub prompt_tokens: Vec, + /// Maximum tokens to generate. + pub max_tokens: usize, + /// Sampling temperature (1.0 = greedy effective). + pub temperature: f32, + /// Stop token ID (default: 50256 = <|endoftext|>). + pub stop_token: Option, +} + +impl Default for CompletionRequest { + fn default() -> Self { + Self { + model: "gpt2".into(), + prompt_tokens: Vec::new(), + max_tokens: 128, + temperature: 1.0, + stop_token: Some(50256), + } + } +} + +/// Single completion choice. +#[derive(Clone, Debug)] +pub struct CompletionChoice { + pub index: usize, + pub tokens: Vec, + pub finish_reason: FinishReason, +} + +/// Why generation stopped. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FinishReason { + Stop, + Length, +} + +/// Response body for /v1/completions. +#[derive(Clone, Debug)] +pub struct CompletionResponse { + pub id: String, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +/// Token usage statistics. +#[derive(Clone, Debug, Default)] +pub struct Usage { + pub prompt_tokens: usize, + pub completion_tokens: usize, + pub total_tokens: usize, +} + +// ============================================================================ +// /v1/embeddings +// ============================================================================ + +/// Request body for /v1/embeddings. +#[derive(Clone, Debug)] +pub struct EmbeddingRequest { + pub model: String, + /// Token IDs to embed (one embedding per token). + pub input_tokens: Vec, +} + +/// Single embedding result. +#[derive(Clone, Debug)] +pub struct EmbeddingData { + pub index: usize, + pub embedding: Vec, +} + +/// Response body for /v1/embeddings. +#[derive(Clone, Debug)] +pub struct EmbeddingResponse { + pub model: String, + pub data: Vec, + pub usage: Usage, +} + +// ============================================================================ +// /v1/models +// ============================================================================ + +/// Model info for /v1/models. +#[derive(Clone, Debug)] +pub struct ModelInfo { + pub id: String, + pub owned_by: String, + pub vocab_size: usize, + pub embed_dim: usize, + pub num_layers: usize, + pub num_heads: usize, + pub max_seq_len: usize, +} + +impl ModelInfo { + /// GPT-2 small (124M) model info. 
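+ ///
+ /// Values mirror the constants in `weights.rs`: 50257 vocab, 768 embed dims,
+ /// 12 layers, 12 heads, 1024 max sequence length (asserted by
+ /// `test_model_info` below).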
+ pub fn gpt2_small() -> Self { + Self { + id: "gpt2".into(), + owned_by: "adaworldapi".into(), + vocab_size: VOCAB_SIZE, + embed_dim: EMBED_DIM, + num_layers: NUM_LAYERS, + num_heads: NUM_HEADS, + max_seq_len: MAX_SEQ_LEN, + } + } +} + +// ============================================================================ +// Engine wrapper — stateless API over stateful engine +// ============================================================================ + +/// Stateless API wrapper around Gpt2Engine. +/// Handles request→response conversion. +pub struct Gpt2Api { + engine: Gpt2Engine, + request_counter: u64, +} + +impl Gpt2Api { + /// Create from pre-loaded weights. + pub fn new(weights: Gpt2Weights) -> Self { + Self { + engine: Gpt2Engine::new(weights), + request_counter: 0, + } + } + + /// /v1/completions handler. + pub fn complete(&mut self, req: &CompletionRequest) -> CompletionResponse { + self.request_counter += 1; + + let generated = self.engine.generate( + &req.prompt_tokens, + req.max_tokens, + req.temperature, + ); + + let finish_reason = if generated.len() < req.max_tokens { + FinishReason::Stop + } else { + FinishReason::Length + }; + + let completion_tokens = generated.len(); + let prompt_tokens = req.prompt_tokens.len(); + + CompletionResponse { + id: format!("cmpl-{}", self.request_counter), + model: "gpt2".into(), + choices: vec![CompletionChoice { + index: 0, + tokens: generated, + finish_reason, + }], + usage: Usage { + prompt_tokens, + completion_tokens, + total_tokens: prompt_tokens + completion_tokens, + }, + } + } + + /// /v1/embeddings handler — returns wte embeddings for token IDs. + pub fn embed(&self, req: &EmbeddingRequest) -> EmbeddingResponse { + let mut data = Vec::with_capacity(req.input_tokens.len()); + + for (idx, &token_id) in req.input_tokens.iter().enumerate() { + let offset = token_id as usize * EMBED_DIM; + let embedding = self.engine.weights().wte[offset..offset + EMBED_DIM].to_vec(); + data.push(EmbeddingData { + index: idx, + embedding, + }); + } + + EmbeddingResponse { + model: "gpt2".into(), + data, + usage: Usage { + prompt_tokens: req.input_tokens.len(), + completion_tokens: 0, + total_tokens: req.input_tokens.len(), + }, + } + } + + /// /v1/models handler. + pub fn model_info(&self) -> ModelInfo { + ModelInfo::gpt2_small() + } + + /// Access the underlying engine (for advanced usage). + pub fn engine_mut(&mut self) -> &mut Gpt2Engine { + &mut self.engine + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_model_info() { + let info = ModelInfo::gpt2_small(); + assert_eq!(info.vocab_size, 50257); + assert_eq!(info.embed_dim, 768); + assert_eq!(info.num_layers, 12); + assert_eq!(info.num_heads, 12); + assert_eq!(info.max_seq_len, 1024); + } + + #[test] + fn test_completion_request_default() { + let req = CompletionRequest::default(); + assert_eq!(req.max_tokens, 128); + assert_eq!(req.temperature, 1.0); + assert_eq!(req.stop_token, Some(50256)); + } + + #[test] + fn test_finish_reason_variants() { + assert_eq!(FinishReason::Stop, FinishReason::Stop); + assert_ne!(FinishReason::Stop, FinishReason::Length); + } +} diff --git a/src/hpc/gpt2/inference.rs b/src/hpc/gpt2/inference.rs new file mode 100644 index 00000000..2c828ec4 --- /dev/null +++ b/src/hpc/gpt2/inference.rs @@ -0,0 +1,642 @@ +//! GPT-2 inference engine — forward pass + generation loop. +//! +//! All transcendental ops use `crate::simd::F32x16`. +//! LayerNorm, GELU, Softmax — all SIMD-accelerated. +//! +//! # Tensor Codec Integration +//! +//! 
Wired through the full bgz17/HHTL/CausalEdge64 stack: +//! - **AttentionTable**: Palette-based O(1) approximate attention scores +//! from `jina::runtime::GPT2` (256×256 precomputed distances). +//! - **CausalEdge64**: Every attention head emits SPO causal edges with +//! NARS truth values. Accumulated during generation for causal reasoning. +//! - **Base17 embeddings**: Available via `jina::runtime::GPT2` for O(1) +//! token similarity (HHTL cascade: HEEL → LEAF). + +use super::weights::*; +use crate::hpc::jina::causal; +use crate::hpc::jina::runtime; +use crate::simd::F32x16; + +/// A generated token with its probability. +#[derive(Clone, Debug)] +pub struct GeneratedToken { + pub token_id: u32, + pub logprob: f32, +} + +/// CausalEdge64 emitted during attention. +/// Each edge encodes: which token attended to which, with what strength. +#[derive(Clone, Debug)] +pub struct AttentionEdge { + /// Transformer layer that produced this edge. + pub layer: u8, + /// Attention head index. + pub head: u8, + /// The packed CausalEdge64 (subject=query token, predicate=head, object=key token). + pub edge: u64, +} + +/// GPT-2 inference engine. +pub struct Gpt2Engine { + weights: Gpt2Weights, + /// KV cache for autoregressive generation. + kv_cache: Vec, + /// Current sequence length. + seq_len: usize, + /// Token IDs seen so far (for palette lookups). + token_history: Vec, + /// Accumulated causal edges from attention patterns. + pub causal_edges: Vec, + /// Whether to use AttentionTable approximation for attention scores. + pub use_attention_table: bool, + /// Whether to emit CausalEdge64 from attention patterns. + pub emit_causal_edges: bool, +} + +/// Key-Value cache for one layer. +#[derive(Clone)] +struct KvCache { + /// Cached keys: [seq_len, embed_dim] + keys: Vec, + /// Cached values: [seq_len, embed_dim] + values: Vec, +} + +impl Gpt2Engine { + /// Create engine from weights. + pub fn new(weights: Gpt2Weights) -> Self { + let kv_cache = (0..NUM_LAYERS) + .map(|_| KvCache { + keys: Vec::with_capacity(MAX_SEQ_LEN * EMBED_DIM), + values: Vec::with_capacity(MAX_SEQ_LEN * EMBED_DIM), + }) + .collect(); + Self { + weights, + kv_cache, + seq_len: 0, + token_history: Vec::with_capacity(MAX_SEQ_LEN), + causal_edges: Vec::new(), + use_attention_table: false, + emit_causal_edges: false, + } + } + + /// Access weights (for embedding lookups). + pub fn weights(&self) -> &Gpt2Weights { + &self.weights + } + + /// Reset KV cache (new conversation). + pub fn reset(&mut self) { + for kv in &mut self.kv_cache { + kv.keys.clear(); + kv.values.clear(); + } + self.seq_len = 0; + self.token_history.clear(); + self.causal_edges.clear(); + } + + /// Get HHTL cascade distance between two tokens via bgz17 Base17 palette. + /// Uses the precomputed GPT2 runtime from `jina::runtime::GPT2`. + #[inline] + pub fn token_similarity(&self, token_a: u32, token_b: u32) -> f32 { + let rt = &*runtime::GPT2; + rt.heel_similarity(token_a as usize, token_b as usize) + } + + /// Get Base17 L1 distance between two tokens (LEAF level, full precision). + #[inline] + pub fn token_distance_leaf(&self, token_a: u32, token_b: u32) -> u32 { + let rt = &*runtime::GPT2; + rt.leaf_distance(token_a as usize, token_b as usize) + } + + /// Get HHTL cascade distance with automatic level selection. + #[inline] + pub fn token_distance_cascade(&self, token_a: u32, token_b: u32) -> (u32, runtime::HhtlLevel) { + let rt = &*runtime::GPT2; + rt.cascade_distance(token_a as usize, token_b as usize) + } + + /// Get 6-byte CAM-PQ fingerprint for a token. 
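+ ///
+ /// Layout follows `ModelRuntime::cam_fingerprint`: byte 0 is the token's
+ /// palette index, bytes 1..=5 are the MSBs of Base17 dims 0..=4. Sketch
+ /// (mirrors `test_token_fingerprint_6bytes` below):
+ ///
+ /// ```ignore
+ /// let fp = engine.token_fingerprint(1000);
+ /// assert_eq!(fp.len(), 6);
+ /// assert_eq!(fp[0], runtime::GPT2.palette.palette_index(1000));
+ /// ```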
+ #[inline] + pub fn token_fingerprint(&self, token_id: u32) -> [u8; 6] { + let rt = &*runtime::GPT2; + rt.cam_fingerprint(token_id as usize) + } + + /// Forward pass for one token → logits over vocabulary. + /// + /// Uses KV cache for O(seq_len) attention instead of O(seq_len²). + /// When `emit_causal_edges` is true, attention patterns are packed + /// as CausalEdge64 with NARS truth values. + pub fn forward(&mut self, token_id: u32) -> Vec { + let pos = self.seq_len; + assert!(pos < MAX_SEQ_LEN, "sequence too long"); + self.token_history.push(token_id); + + // Embedding: wte[token] + wpe[position] + let mut hidden = vec![0.0f32; EMBED_DIM]; + let wte_offset = token_id as usize * EMBED_DIM; + let wpe_offset = pos * EMBED_DIM; + for i in 0..EMBED_DIM { + hidden[i] = self.weights.wte[wte_offset + i] + self.weights.wpe[wpe_offset + i]; + } + + // 12 transformer layers + for layer_idx in 0..NUM_LAYERS { + hidden = self.transformer_layer(layer_idx, &hidden); + } + + // Final layer norm + layer_norm_simd(&mut hidden, &self.weights.ln_f_weight, &self.weights.ln_f_bias); + + // Logits: hidden @ wte.T (weight tying) + let mut logits = vec![0.0f32; VOCAB_SIZE]; + let chunks = EMBED_DIM / 16; + for v in 0..VOCAB_SIZE { + let wte_off = v * EMBED_DIM; + let mut acc = F32x16::splat(0.0); + for c in 0..chunks { + let off = c * 16; + let vh = F32x16::from_slice(&hidden[off..off + 16]); + let vw = F32x16::from_slice(&self.weights.wte[wte_off + off..wte_off + off + 16]); + acc = vh.mul_add(vw, acc); + } + logits[v] = acc.reduce_sum(); + } + + self.seq_len += 1; + logits + } + + /// One transformer layer: attention + MLP with residuals. + fn transformer_layer(&mut self, layer_idx: usize, input: &[f32]) -> Vec { + // Clone LayerNorm params before mutable borrow on self (KV cache). + let ln1_w = self.weights.layers[layer_idx].ln1_weight.clone(); + let ln1_b = self.weights.layers[layer_idx].ln1_bias.clone(); + let ln2_w = self.weights.layers[layer_idx].ln2_weight.clone(); + let ln2_b = self.weights.layers[layer_idx].ln2_bias.clone(); + + // Pre-attention LayerNorm + let mut normed = input.to_vec(); + layer_norm_simd(&mut normed, &ln1_w, &ln1_b); + + // Attention: Q/K/V projection → scaled dot-product → output + let attn_out = self.multi_head_attention(layer_idx, &normed); + + // Residual connection + let mut hidden: Vec = input.iter().zip(&attn_out).map(|(a, b)| a + b).collect(); + + // Pre-MLP LayerNorm + let mut normed2 = hidden.clone(); + layer_norm_simd(&mut normed2, &ln2_w, &ln2_b); + + // MLP: fc → GELU → proj + let mlp_out = self.mlp(layer_idx, &normed2); + + // Residual connection + for i in 0..EMBED_DIM { + hidden[i] += mlp_out[i]; + } + + hidden + } + + /// Multi-head self-attention with KV cache. + /// + /// Integration points: + /// - **AttentionTable**: When `use_attention_table` is set, palette-based + /// similarity biases the attention scores (HEEL-level O(1) lookup). + /// - **CausalEdge64**: When `emit_causal_edges` is set, top attention + /// weights are packed as SPO edges with NARS truth values. 
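+ ///
+ /// A sketch of how a caller enables both paths (field names as defined on
+ /// `Gpt2Engine` above; edge decoding helpers live in `jina::causal`):
+ ///
+ /// ```ignore
+ /// let mut engine = Gpt2Engine::new(weights);
+ /// engine.use_attention_table = true;  // blend palette similarity into scores
+ /// engine.emit_causal_edges = true;    // collect SPO edges during attention
+ /// let _ = engine.generate(&prompt, 32, 1.0);
+ /// for e in &engine.causal_edges {
+ ///     let freq = causal::edge_freq(e.edge); // attention weight that produced it
+ /// }
+ /// ```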
+ fn multi_head_attention(&mut self, layer_idx: usize, input: &[f32]) -> Vec { + let layer = &self.weights.layers[layer_idx]; + + // Q/K/V projection: input[768] × weight[768, 2304] + bias[2304] + let mut qkv = vec![0.0f32; 3 * EMBED_DIM]; // [Q(768), K(768), V(768)] + matmul_vec_simd(input, &layer.attn_qkv_weight, &layer.attn_qkv_bias, &mut qkv, EMBED_DIM, 3 * EMBED_DIM); + + let q = &qkv[..EMBED_DIM]; + let k = &qkv[EMBED_DIM..2 * EMBED_DIM]; + let v = &qkv[2 * EMBED_DIM..3 * EMBED_DIM]; + + // Append K, V to cache + self.kv_cache[layer_idx].keys.extend_from_slice(k); + self.kv_cache[layer_idx].values.extend_from_slice(v); + + let seq_len = self.seq_len + 1; // including current token + let current_token = *self.token_history.last().unwrap_or(&0); + let use_attn_table = self.use_attention_table; + let emit_edges = self.emit_causal_edges; + + // Per-head attention + let mut output = vec![0.0f32; EMBED_DIM]; + let scale = 1.0 / (HEAD_DIM as f32).sqrt(); + + // Lazy-init GPT2 palette runtime for AttentionTable / CausalEdge64 + let rt = if use_attn_table || emit_edges { + Some(&*runtime::GPT2) + } else { + None + }; + + for head in 0..NUM_HEADS { + let h_offset = head * HEAD_DIM; + + // Compute attention scores: Q[head] · K[head]^T for all cached positions + let mut scores = vec![0.0f32; seq_len]; + for t in 0..seq_len { + let k_offset = t * EMBED_DIM + h_offset; + let mut dot = 0.0f32; + for d in 0..HEAD_DIM { + dot += q[h_offset + d] * self.kv_cache[layer_idx].keys[k_offset + d]; + } + scores[t] = dot * scale; + + // AttentionTable bias: blend palette-based similarity into score. + // This is the "compiled attention" path — the 256×256 palette + // distance table provides O(1) semantic similarity. + if let Some(rt) = rt { + if use_attn_table && t < self.token_history.len() { + let key_token = self.token_history[t]; + let palette_sim = rt.heel_similarity( + current_token as usize, + key_token as usize, + ); + // Blend: 90% matmul score + 10% palette shortcut + scores[t] = scores[t] * 0.9 + palette_sim * 0.1 * scale; + } + } + } + + // Causal mask: already enforced by cache length + softmax_simd(&mut scores); + + // Emit CausalEdge64 for significant attention weights. + // S=current_token, P=head (via palette), O=attended_token. + if emit_edges { + if let Some(rt) = rt { + for t in 0..seq_len { + if scores[t] > 0.05 && t < self.token_history.len() { + let key_token = self.token_history[t]; + let edge = rt.pack_spo_edge( + current_token as usize, + head, // predicate = attention head + key_token as usize, + scores[t], // frequency = attention weight + 0.3, // initial confidence (low) + self.seq_len as u16, // temporal position + ); + self.causal_edges.push(AttentionEdge { + layer: layer_idx as u8, + head: head as u8, + edge, + }); + } + } + } + } + + // Weighted sum of values + for t in 0..seq_len { + let v_offset = t * EMBED_DIM + h_offset; + let w = scores[t]; + for d in 0..HEAD_DIM { + output[h_offset + d] += w * self.kv_cache[layer_idx].values[v_offset + d]; + } + } + } + + // Output projection: output[768] × weight[768, 768] + bias[768] + let mut projected = vec![0.0f32; EMBED_DIM]; + matmul_vec_simd(&output, &layer.attn_out_weight, &layer.attn_out_bias, &mut projected, EMBED_DIM, EMBED_DIM); + + projected + } + + /// MLP: fc[768→3072] → GELU → proj[3072→768]. 
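+ ///
+ /// Both projections go through `matmul_vec_simd` on weights that were
+ /// pre-transposed to `[out_dim, in_dim]` when `Gpt2Weights::from_safetensors`
+ /// loaded them, so each dot product reads a contiguous row.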
+ fn mlp(&self, layer_idx: usize, input: &[f32]) -> Vec { + let layer = &self.weights.layers[layer_idx]; + + // FC: input[768] × weight[768, 3072] + bias[3072] + let mut fc_out = vec![0.0f32; MLP_DIM]; + matmul_vec_simd(input, &layer.mlp_fc_weight, &layer.mlp_fc_bias, &mut fc_out, EMBED_DIM, MLP_DIM); + + // GELU activation (via SIMD) + gelu_simd(&mut fc_out); + + // Proj: fc_out[3072] × weight[3072, 768] + bias[768] + let mut output = vec![0.0f32; EMBED_DIM]; + matmul_vec_simd(&fc_out, &layer.mlp_proj_weight, &layer.mlp_proj_bias, &mut output, MLP_DIM, EMBED_DIM); + + output + } + + /// Generate tokens autoregressively. + pub fn generate(&mut self, prompt_tokens: &[u32], max_new_tokens: usize, temperature: f32) -> Vec { + self.reset(); + let mut generated = Vec::new(); + + // Process prompt (fill KV cache) + let mut last_logits = vec![0.0f32; VOCAB_SIZE]; + for &token in prompt_tokens { + last_logits = self.forward(token); + } + + // Generate new tokens + for _ in 0..max_new_tokens { + // Apply temperature + if temperature != 1.0 { + for l in &mut last_logits { + *l /= temperature; + } + } + + // Greedy: argmax + let mut best_id = 0u32; + let mut best_logit = f32::NEG_INFINITY; + for (i, &l) in last_logits.iter().enumerate() { + if l > best_logit { + best_logit = l; + best_id = i as u32; + } + } + + // End of text + if best_id == 50256 { + break; + } + + generated.push(GeneratedToken { + token_id: best_id, + logprob: best_logit, + }); + + // Feed generated token back + last_logits = self.forward(best_id); + } + + generated + } +} + +// ============================================================================ +// SIMD-accelerated primitives (all use crate::simd::F32x16) +// ============================================================================ + +/// Layer normalization with F32x16 SIMD. 
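+///
+/// In-place: `x = (x - mean(x)) / sqrt(var(x) + 1e-5) * weight + bias`,
+/// the same formula as the shared `models::layers::layer_norm`.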
+fn layer_norm_simd(x: &mut [f32], weight: &[f32], bias: &[f32]) { + let n = x.len(); + + // Mean (SIMD) + let chunks = n / 16; + let mut sum_acc = F32x16::splat(0.0); + for c in 0..chunks { + let off = c * 16; + sum_acc = sum_acc + F32x16::from_slice(&x[off..off + 16]); + } + let mut mean = sum_acc.reduce_sum(); + for i in (chunks * 16)..n { + mean += x[i]; + } + mean /= n as f32; + + // Variance (SIMD) + let mean_vec = F32x16::splat(mean); + let mut var_acc = F32x16::splat(0.0); + for c in 0..chunks { + let off = c * 16; + let diff = F32x16::from_slice(&x[off..off + 16]) - mean_vec; + var_acc = diff.mul_add(diff, var_acc); + } + let mut var = var_acc.reduce_sum(); + for i in (chunks * 16)..n { + let d = x[i] - mean; + var += d * d; + } + var /= n as f32; + let inv_std = 1.0 / (var + 1e-5).sqrt(); + + // Normalize + scale + shift (SIMD) + let inv_std_vec = F32x16::splat(inv_std); + for c in 0..chunks { + let off = c * 16; + let val = F32x16::from_slice(&x[off..off + 16]); + let w = F32x16::from_slice(&weight[off..off + 16]); + let b = F32x16::from_slice(&bias[off..off + 16]); + let normed = (val - mean_vec) * inv_std_vec; + let result = normed * w + b; + result.copy_to_slice(&mut x[off..off + 16]); + } + for i in (chunks * 16)..n { + x[i] = (x[i] - mean) * inv_std * weight[i] + bias[i]; + } +} + +/// GELU activation: x * 0.5 * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))) +fn gelu_simd(x: &mut [f32]) { + let n = x.len(); + let sqrt_2_over_pi = F32x16::splat(0.7978845608); // sqrt(2/π) + let coeff = F32x16::splat(0.044715); + let half = F32x16::splat(0.5); + let one = F32x16::splat(1.0); + + let chunks = n / 16; + for c in 0..chunks { + let off = c * 16; + let v = F32x16::from_slice(&x[off..off + 16]); + let v3 = v * v * v; + let inner = sqrt_2_over_pi * (v + coeff * v3); + // tanh approximation via exp: tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) + let two_inner = inner + inner; + let exp_2x = crate::simd::simd_exp_f32(two_inner); + let tanh_v = (exp_2x - one) / (exp_2x + one); + let result = v * half * (one + tanh_v); + result.copy_to_slice(&mut x[off..off + 16]); + } + for i in (chunks * 16)..n { + let v = x[i]; + let inner = 0.7978845608 * (v + 0.044715 * v * v * v); + let tanh_v = inner.tanh(); + x[i] = v * 0.5 * (1.0 + tanh_v); + } +} + +/// Softmax with numerical stability (SIMD). +fn softmax_simd(x: &mut [f32]) { + // Find max (for numerical stability) + let mut max_val = f32::NEG_INFINITY; + for &v in x.iter() { + if v > max_val { + max_val = v; + } + } + + // exp(x - max) and sum + let mut sum = 0.0f32; + for v in x.iter_mut() { + *v = (*v - max_val).exp(); + sum += *v; + } + + // Normalize + let inv_sum = 1.0 / sum; + for v in x.iter_mut() { + *v *= inv_sum; + } +} + +/// Matrix-vector multiply: out = input @ weight^T + bias. +/// Weight stored as [input_dim, output_dim] (row-major, transposed access). +/// SIMD accelerated for the dot product. +/// Matrix-vector multiply: out = input @ weight + bias. +/// Weight is PRE-TRANSPOSED to [out_dim, in_dim] for contiguous SIMD access. +/// Each output element reads a contiguous row of in_dim floats. 
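+///
+/// Small worked example (identity weight, already `[out_dim, in_dim]`;
+/// mirrors `test_matmul_vec_identity` in `models::layers`):
+///
+/// ```ignore
+/// let input = [3.0, 7.0];
+/// let weight = [1.0, 0.0, 0.0, 1.0];
+/// let bias = [0.0, 0.0];
+/// let mut out = [0.0f32; 2];
+/// matmul_vec_simd(&input, &weight, &bias, &mut out, 2, 2);
+/// // out == [3.0, 7.0]
+/// ```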
+fn matmul_vec_simd(input: &[f32], weight: &[f32], bias: &[f32], output: &mut [f32], in_dim: usize, out_dim: usize) { + let chunks = in_dim / 16; + let remainder = in_dim % 16; + + for o in 0..out_dim { + let row_offset = o * in_dim; + let mut acc = F32x16::splat(0.0); + for c in 0..chunks { + let off = c * 16; + let vi = F32x16::from_slice(&input[off..off + 16]); + let vw = F32x16::from_slice(&weight[row_offset + off..row_offset + off + 16]); + acc = vi.mul_add(vw, acc); + } + let mut dot = acc.reduce_sum(); + // Scalar tail + let tail_start = chunks * 16; + for i in 0..remainder { + dot += input[tail_start + i] * weight[row_offset + tail_start + i]; + } + output[o] = dot + bias[o]; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_layer_norm_identity() { + let mut x = vec![1.0, 2.0, 3.0, 4.0]; + let w = vec![1.0; 4]; + let b = vec![0.0; 4]; + layer_norm_simd(&mut x, &w, &b); + // After normalization: mean≈0, std≈1 + let mean: f32 = x.iter().sum::() / 4.0; + assert!(mean.abs() < 0.01, "mean should be ~0, got {}", mean); + } + + #[test] + fn test_gelu_zero() { + let mut x = vec![0.0f32; 16]; + gelu_simd(&mut x); + assert!(x[0].abs() < 0.01, "GELU(0) should be ~0"); + } + + #[test] + fn test_gelu_positive() { + let mut x = vec![2.0f32; 16]; + gelu_simd(&mut x); + // GELU(2) ≈ 1.9545 + assert!((x[0] - 1.9545).abs() < 0.01, "GELU(2) ≈ 1.95, got {}", x[0]); + } + + #[test] + fn test_softmax_sums_to_one() { + let mut x = vec![1.0, 2.0, 3.0, 4.0]; + softmax_simd(&mut x); + let sum: f32 = x.iter().sum(); + assert!((sum - 1.0).abs() < 1e-5, "softmax should sum to 1.0, got {}", sum); + } + + #[test] + fn test_softmax_argmax_preserved() { + let mut x = vec![1.0, 5.0, 2.0, 3.0]; + softmax_simd(&mut x); + // Index 1 (value 5.0) should have highest probability + assert!(x[1] > x[0] && x[1] > x[2] && x[1] > x[3]); + } + + // ===== Tensor codec integration tests ===== + + #[test] + fn test_token_similarity_self() { + // Token similarity to itself should be ~1.0 (via GPT2 palette) + let engine = Gpt2Engine::new(Gpt2Weights { + wte: vec![0.0; VOCAB_SIZE * EMBED_DIM], + wpe: vec![0.0; MAX_SEQ_LEN * EMBED_DIM], + layers: Vec::new(), + ln_f_weight: vec![1.0; EMBED_DIM], + ln_f_bias: vec![0.0; EMBED_DIM], + }); + let sim = engine.token_similarity(0, 0); + assert!((sim - 1.0).abs() < 0.01, "self-similarity should be ~1.0, got {}", sim); + } + + #[test] + fn test_token_similarity_different() { + let engine = Gpt2Engine::new(Gpt2Weights { + wte: vec![0.0; VOCAB_SIZE * EMBED_DIM], + wpe: vec![0.0; MAX_SEQ_LEN * EMBED_DIM], + layers: Vec::new(), + ln_f_weight: vec![1.0; EMBED_DIM], + ln_f_bias: vec![0.0; EMBED_DIM], + }); + let sim = engine.token_similarity(100, 50000); + assert!(sim < 1.0, "different tokens should have similarity < 1.0"); + } + + #[test] + fn test_token_fingerprint_6bytes() { + let engine = Gpt2Engine::new(Gpt2Weights { + wte: vec![0.0; VOCAB_SIZE * EMBED_DIM], + wpe: vec![0.0; MAX_SEQ_LEN * EMBED_DIM], + layers: Vec::new(), + ln_f_weight: vec![1.0; EMBED_DIM], + ln_f_bias: vec![0.0; EMBED_DIM], + }); + let fp = engine.token_fingerprint(1000); + assert_eq!(fp.len(), 6); + // First byte is palette index + let rt = &*runtime::GPT2; + assert_eq!(fp[0], rt.palette.palette_index(1000)); + } + + #[test] + fn test_cascade_distance_levels() { + let engine = Gpt2Engine::new(Gpt2Weights { + wte: vec![0.0; VOCAB_SIZE * EMBED_DIM], + wpe: vec![0.0; MAX_SEQ_LEN * EMBED_DIM], + layers: Vec::new(), + ln_f_weight: vec![1.0; EMBED_DIM], + ln_f_bias: vec![0.0; EMBED_DIM], + }); + // 
Self-distance should resolve at HEEL level + let (d, level) = engine.token_distance_cascade(0, 0); + assert_eq!(d, 0); + assert_eq!(level, runtime::HhtlLevel::Heel); + } + + #[test] + fn test_causal_edge_emission_flag() { + // Verify that emit_causal_edges flag controls edge emission + let mut engine = Gpt2Engine::new(Gpt2Weights { + wte: vec![0.0; VOCAB_SIZE * EMBED_DIM], + wpe: vec![0.0; MAX_SEQ_LEN * EMBED_DIM], + layers: Vec::new(), + ln_f_weight: vec![1.0; EMBED_DIM], + ln_f_bias: vec![0.0; EMBED_DIM], + }); + assert!(!engine.emit_causal_edges, "should be off by default"); + assert!(!engine.use_attention_table, "should be off by default"); + assert!(engine.causal_edges.is_empty()); + } +} diff --git a/src/hpc/gpt2/mod.rs b/src/hpc/gpt2/mod.rs new file mode 100644 index 00000000..fad2fc3d --- /dev/null +++ b/src/hpc/gpt2/mod.rs @@ -0,0 +1,30 @@ +//! GPT-2 inference engine — autoregressive text generation on CPU. +//! +//! Full GPT-2 (124M) running through: +//! - `crate::simd::F32x16` for all transcendental ops +//! - Base17 palette for O(1) embedding lookup +//! - Optional AttentionTable for O(1) attention (when compiled) +//! - CausalEdge64 for causal reasoning on generated tokens +//! +//! # Architecture +//! +//! ```text +//! Input text → BPE tokenize → token IDs +//! → wte[token_id] (embedding lookup) + wpe[position] (positional) +//! → 12 transformer layers: +//! LayerNorm → MultiHeadAttention → Residual +//! LayerNorm → MLP (GELU) → Residual +//! → LayerNorm → logits → argmax/sample → next token +//! → repeat until <|endoftext|> or max_tokens +//! ``` +//! +//! # Speed Target +//! +//! GPT-2 small: 768D, 12 layers, 12 heads, 50K vocab. +//! Full matmul path: ~50ms per token on single CPU core. +//! With AttentionTable: ~5ms per token (10× faster). +//! With SIMD exp/sigmoid: ~30% faster transcendentals. + +pub mod weights; +pub mod inference; +pub mod api; diff --git a/src/hpc/gpt2/weights.rs b/src/hpc/gpt2/weights.rs new file mode 100644 index 00000000..f3ed4542 --- /dev/null +++ b/src/hpc/gpt2/weights.rs @@ -0,0 +1,215 @@ +//! GPT-2 weight loading from safetensors format. +//! +//! Loads ALL weights needed for inference, not just embeddings. +//! Each tensor stored as contiguous f32 arrays for SIMD access. + +use std::collections::HashMap; +use std::io::{Read, Seek, SeekFrom}; + +/// GPT-2 model configuration. +pub const VOCAB_SIZE: usize = 50257; +pub const EMBED_DIM: usize = 768; +pub const NUM_LAYERS: usize = 12; +pub const NUM_HEADS: usize = 12; +pub const HEAD_DIM: usize = EMBED_DIM / NUM_HEADS; // 64 +pub const MLP_DIM: usize = 3072; // 4 × EMBED_DIM +pub const MAX_SEQ_LEN: usize = 1024; + +/// All weights for one transformer layer. +#[derive(Clone)] +pub struct LayerWeights { + /// Attention layer norm: weight [768] + bias [768] + pub ln1_weight: Vec, + pub ln1_bias: Vec, + /// Combined Q/K/V projection: [768, 2304] + bias [2304] + pub attn_qkv_weight: Vec, + pub attn_qkv_bias: Vec, + /// Output projection: [768, 768] + bias [768] + pub attn_out_weight: Vec, + pub attn_out_bias: Vec, + /// MLP layer norm: weight [768] + bias [768] + pub ln2_weight: Vec, + pub ln2_bias: Vec, + /// MLP fc: [768, 3072] + bias [3072] + pub mlp_fc_weight: Vec, + pub mlp_fc_bias: Vec, + /// MLP proj: [3072, 768] + bias [768] + pub mlp_proj_weight: Vec, + pub mlp_proj_bias: Vec, +} + +/// Complete GPT-2 model weights. 
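+///
+/// A sketch of the load-then-infer flow (the path is hypothetical):
+///
+/// ```ignore
+/// let weights = Gpt2Weights::from_safetensors(std::path::Path::new("gpt2/model.safetensors"))
+///     .expect("load gpt2 weights");
+/// let mut engine = crate::hpc::gpt2::inference::Gpt2Engine::new(weights);
+/// let logits = engine.forward(50256); // one step; logits over the 50257-token vocab
+/// ```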
+#[derive(Clone)] +pub struct Gpt2Weights { + /// Token embedding: [50257, 768] + pub wte: Vec, + /// Position embedding: [1024, 768] + pub wpe: Vec, + /// Transformer layers + pub layers: Vec, + /// Final layer norm + pub ln_f_weight: Vec, + pub ln_f_bias: Vec, +} + +impl Gpt2Weights { + /// Load from a safetensors file via our memory-mapped reader. + /// + /// This reads ALL weights needed for inference (~500MB f32). + /// For production: weights would be quantized or compiled to AttentionTable. + pub fn from_safetensors(path: &std::path::Path) -> Result { + // Safetensors format: [header_size:u64_le][header_json][tensor_data] + let file = std::fs::read(path).map_err(|e| e.to_string())?; + + let header_size = u64::from_le_bytes([ + file[0], file[1], file[2], file[3], + file[4], file[5], file[6], file[7], + ]) as usize; + + let header_json = std::str::from_utf8(&file[8..8 + header_size]) + .map_err(|e| e.to_string())?; + + // Parse tensor metadata from JSON header + let data_start = 8 + header_size; + let tensors = parse_safetensors_header(header_json)?; + + let read_tensor = |name: &str| -> Result, String> { + let info = tensors.get(name) + .ok_or_else(|| format!("Missing tensor: {}", name))?; + let start = data_start + info.offset; + let end = start + info.size; + if end > file.len() { + return Err(format!("Tensor {} extends beyond file", name)); + } + Ok(file[start..end] + .chunks_exact(4) + .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) + .collect()) + }; + + let wte = read_tensor("wte.weight")?; + let wpe = read_tensor("wpe.weight")?; + let ln_f_weight = read_tensor("ln_f.weight")?; + let ln_f_bias = read_tensor("ln_f.bias")?; + + let mut layers = Vec::with_capacity(NUM_LAYERS); + for i in 0..NUM_LAYERS { + let prefix = format!("h.{}", i); + layers.push(LayerWeights { + ln1_weight: read_tensor(&format!("{}.ln_1.weight", prefix))?, + ln1_bias: read_tensor(&format!("{}.ln_1.bias", prefix))?, + attn_qkv_weight: read_tensor(&format!("{}.attn.c_attn.weight", prefix))?, + attn_qkv_bias: read_tensor(&format!("{}.attn.c_attn.bias", prefix))?, + attn_out_weight: read_tensor(&format!("{}.attn.c_proj.weight", prefix))?, + attn_out_bias: read_tensor(&format!("{}.attn.c_proj.bias", prefix))?, + ln2_weight: read_tensor(&format!("{}.ln_2.weight", prefix))?, + ln2_bias: read_tensor(&format!("{}.ln_2.bias", prefix))?, + mlp_fc_weight: read_tensor(&format!("{}.mlp.c_fc.weight", prefix))?, + mlp_fc_bias: read_tensor(&format!("{}.mlp.c_fc.bias", prefix))?, + mlp_proj_weight: read_tensor(&format!("{}.mlp.c_proj.weight", prefix))?, + mlp_proj_bias: read_tensor(&format!("{}.mlp.c_proj.bias", prefix))?, + }); + } + + let mut weights = Gpt2Weights { + wte, wpe, layers, ln_f_weight, ln_f_bias, + }; + weights.transpose_weights_for_simd(); + Ok(weights) + } + + /// Transpose all weight matrices from [in_dim, out_dim] to [out_dim, in_dim]. + /// After this, matmul can read weight rows contiguously for F32x16 SIMD. + fn transpose_weights_for_simd(&mut self) { + for layer in &mut self.layers { + transpose_matrix(&mut layer.attn_qkv_weight, EMBED_DIM, 3 * EMBED_DIM); + transpose_matrix(&mut layer.attn_out_weight, EMBED_DIM, EMBED_DIM); + transpose_matrix(&mut layer.mlp_fc_weight, EMBED_DIM, MLP_DIM); + transpose_matrix(&mut layer.mlp_proj_weight, MLP_DIM, EMBED_DIM); + } + } +} + +/// Transpose a [rows, cols] matrix in-place to [cols, rows]. 
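+///
+/// Example: the 2×3 row-major matrix `[a b c; d e f]` becomes the 3×2 matrix
+/// `[a d; b e; c f]`; element `(r, c)` moves from index `r*cols + c` to
+/// `c*rows + r`.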
+fn transpose_matrix(data: &mut Vec, rows: usize, cols: usize) { + assert_eq!(data.len(), rows * cols); + let mut transposed = vec![0.0f32; rows * cols]; + for r in 0..rows { + for c in 0..cols { + transposed[c * rows + r] = data[r * cols + c]; + } + } + *data = transposed; +} + +/// Tensor metadata from safetensors header. +struct TensorMeta { + offset: usize, + size: usize, +} + +/// Parse safetensors JSON header to get tensor offsets and sizes. +fn parse_safetensors_header(json: &str) -> Result, String> { + // Minimal JSON parser for safetensors header format: + // { "tensor_name": { "dtype": "F32", "shape": [...], "data_offsets": [start, end] }, ... } + let mut tensors = HashMap::new(); + + // Find each tensor entry + let mut pos = 0; + while let Some(key_start) = json[pos..].find('"') { + let key_start = pos + key_start + 1; + let key_end = match json[key_start..].find('"') { + Some(e) => key_start + e, + None => break, + }; + let key = &json[key_start..key_end]; + pos = key_end + 1; + + // Skip __metadata__ + if key == "__metadata__" { + if let Some(end) = json[pos..].find('}') { + pos += end + 1; + } + continue; + } + + // Find data_offsets + if let Some(offsets_start) = json[pos..].find("data_offsets") { + let search_start = pos + offsets_start; + if let Some(bracket_start) = json[search_start..].find('[') { + let arr_start = search_start + bracket_start + 1; + if let Some(bracket_end) = json[arr_start..].find(']') { + let arr = &json[arr_start..arr_start + bracket_end]; + let nums: Vec = arr.split(',') + .filter_map(|s| s.trim().parse().ok()) + .collect(); + if nums.len() == 2 { + tensors.insert(key.to_string(), TensorMeta { + offset: nums[0], + size: nums[1] - nums[0], + }); + } + } + } + } + + // Advance past the closing brace of this tensor's value + if let Some(brace) = json[pos..].find('}') { + pos += brace + 1; + } + } + + Ok(tensors) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_consistency() { + assert_eq!(EMBED_DIM, NUM_HEADS * HEAD_DIM); + assert_eq!(MLP_DIM, 4 * EMBED_DIM); + assert_eq!(VOCAB_SIZE, 50257); + } +} diff --git a/src/hpc/jina/mod.rs b/src/hpc/jina/mod.rs index cfa3c770..fbd4fd20 100644 --- a/src/hpc/jina/mod.rs +++ b/src/hpc/jina/mod.rs @@ -24,3 +24,4 @@ pub mod cache; pub mod codec; pub mod causal; +pub mod runtime; diff --git a/src/hpc/jina/runtime.rs b/src/hpc/jina/runtime.rs new file mode 100644 index 00000000..e7962039 --- /dev/null +++ b/src/hpc/jina/runtime.rs @@ -0,0 +1,298 @@ +//! Runtime loader: wire Base17 + palette caches through the full tensor codec. +//! +//! Connects the pre-computed weights to: +//! - HHTL cascade (HEEL/HIP/TWIG/LEAF levels) +//! - CAM-PQ style 6-byte fingerprints +//! - CausalEdge64 S/P/O palette indices +//! - SimilarityTable calibration (256-entry CDF) + +use super::cache::{load_base17_cache, load_palette_cache}; +use super::causal; +use super::codec::{Base17Token, JinaPalette, BASE_DIM, PALETTE_K}; +use std::sync::LazyLock; + +/// Embedded weight files (compiled into the binary via include_bytes!). +/// Zero file I/O at runtime — the weights ARE the binary. 
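+/// The blobs are decoded once by the `LazyLock` runtimes below (`JINA`,
+/// `GPT2`, `BERT`) through `load_base17_cache` / `load_palette_cache`;
+/// callers go through those runtimes rather than reading the bytes directly.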
+static JINA_BASE17: &[u8] = include_bytes!("weights/jina_base17_20k.bin"); +static JINA_PALETTE: &[u8] = include_bytes!("weights/jina_palette_20k.bin"); +static GPT2_BASE17: &[u8] = include_bytes!("weights/gpt2_base17_50k.bin"); +static GPT2_PALETTE: &[u8] = include_bytes!("weights/gpt2_palette_50k.bin"); +static BERT_BASE17: &[u8] = include_bytes!("weights/bert_base17_30k.bin"); +static BERT_PALETTE: &[u8] = include_bytes!("weights/bert_palette_30k.bin"); + +/// Which model's weights to use. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ModelSource { + /// Jina v4 text-retrieval (20K tokens, 2048D original). + JinaV4, + /// GPT-2 small (50K tokens, 768D original). Same BPE as Jina. + Gpt2, + /// BERT base uncased (30K tokens, 768D original). WordPiece tokenizer. + Bert, +} + +/// The full runtime: Base17 tokens + palette + distance table + HHTL cascade. +/// Loaded once via LazyLock. Zero cost after first access. +pub struct ModelRuntime { + /// Source model identifier. + pub source: ModelSource, + /// All token embeddings in Base17 format (34 bytes each). + pub tokens: Vec, + /// 256-entry palette with precomputed 256×256 distance table. + pub palette: JinaPalette, + /// SimilarityTable: 256-entry CDF calibration (distance → f32 [0,1]). + pub similarity: [f32; 256], +} + +impl ModelRuntime { + /// Load from embedded weight bytes. + fn load(source: ModelSource, base17_bytes: &[u8], palette_bytes: &[u8]) -> Self { + let tokens = load_base17_cache(&mut std::io::Cursor::new(base17_bytes)) + .expect("Failed to load Base17 cache"); + let palette = load_palette_cache(&mut std::io::Cursor::new(palette_bytes)) + .expect("Failed to load palette cache"); + + // Build SimilarityTable from the EXACT 256×256 distance distribution. + // This IS the bgz17 SimilarityTable pattern: empirical CDF → calibrated f32. + let similarity = build_similarity_table(&palette); + + ModelRuntime { + source, + tokens, + palette, + similarity, + } + } + + /// HHTL HEEL: palette index distance (1 byte per token, O(1)). + #[inline(always)] + pub fn heel_distance(&self, token_a: usize, token_b: usize) -> u16 { + self.palette.distance(token_a, token_b) + } + + /// HHTL HEEL: calibrated similarity via SimilarityTable [0.0, 1.0]. + #[inline(always)] + pub fn heel_similarity(&self, token_a: usize, token_b: usize) -> f32 { + let d = self.heel_distance(token_a, token_b) as usize; + self.similarity[d.min(255)] + } + + /// HHTL TWIG: Base17 L1 distance (34 bytes per token, full resolution). + #[inline(always)] + pub fn leaf_distance(&self, token_a: usize, token_b: usize) -> u32 { + self.tokens[token_a].l1(&self.tokens[token_b]) + } + + /// HHTL cascade: HEEL first, escalate to LEAF if needed. + /// Returns (distance, level_used). Stops as soon as ranking is confident. + #[inline] + pub fn cascade_distance(&self, token_a: usize, token_b: usize) -> (u32, HhtlLevel) { + let heel = self.heel_distance(token_a, token_b); + + // Trivial cases: same palette entry or very far apart + if heel == 0 { + return (0, HhtlLevel::Heel); + } + if heel > 500 { + return (heel as u32, HhtlLevel::Heel); + } + + // Ambiguous zone: escalate to LEAF for precision + let leaf = self.leaf_distance(token_a, token_b); + (leaf, HhtlLevel::Leaf) + } + + /// Pack two tokens + a predicate into a CausalEdge64. 
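+ ///
+ /// Round-trip sketch (mirrors `test_pack_spo_edge` below):
+ ///
+ /// ```ignore
+ /// let edge = rt.pack_spo_edge(100, 200, 300, 0.8, 0.6, 42);
+ /// assert_eq!(causal::edge_temporal(edge), 42);
+ /// assert!((causal::edge_freq(edge) - 0.8).abs() < 0.01);
+ /// ```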
+ #[inline] + pub fn pack_spo_edge( + &self, + subject_token: usize, + predicate_token: usize, + object_token: usize, + frequency: f32, + confidence: f32, + temporal: u16, + ) -> u64 { + causal::pack_edge( + self.palette.palette_index(subject_token), + self.palette.palette_index(predicate_token), + self.palette.palette_index(object_token), + frequency, + confidence, + 0b111, // full SPO Pearl mask + temporal, + ) + } + + /// CAM-PQ style 6-byte fingerprint: [palette_idx, base17_dim0..4]. + #[inline] + pub fn cam_fingerprint(&self, token: usize) -> [u8; 6] { + let pal = self.palette.palette_index(token); + let b17 = &self.tokens[token].dims; + [ + pal, + (b17[0].wrapping_shr(8)) as u8, // BRANCH: sign dimension (MSB of dim 0) + (b17[1].wrapping_shr(8)) as u8, // TWIG_A: dim 1 MSB + (b17[2].wrapping_shr(8)) as u8, // TWIG_B: dim 2 MSB + (b17[3].wrapping_shr(8)) as u8, // LEAF: dim 3 MSB + (b17[4].wrapping_shr(8)) as u8, // GAMMA: dim 4 MSB + ] + } + + /// Token count. + pub fn vocab_size(&self) -> usize { + self.tokens.len() + } +} + +/// HHTL cascade level that resolved the distance. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum HhtlLevel { + /// Palette-level distance (1 byte per token). + Heel, + /// Full Base17 L1 distance (34 bytes per token). + Leaf, +} + +/// Build SimilarityTable from the 256×256 palette distance distribution. +/// Empirical CDF: count how many pairs have distance ≤ d, normalize. +fn build_similarity_table(palette: &JinaPalette) -> [f32; 256] { + // Collect all pairwise distances + let mut all_distances = Vec::with_capacity(PALETTE_K * (PALETTE_K - 1) / 2); + for i in 0..PALETTE_K { + for j in (i + 1)..PALETTE_K { + all_distances.push(palette.distance_table[i][j] as u32); + } + } + all_distances.sort(); + + let n = all_distances.len() as f32; + let max_d = all_distances.last().copied().unwrap_or(1) as usize; + + // Build CDF: similarity[d] = 1.0 - (fraction of pairs with distance ≤ d) + let mut table = [0.0f32; 256]; + for bucket in 0..256 { + let threshold = if max_d > 0 { + (bucket as u64 * max_d as u64 / 255) as u32 + } else { + 0 + }; + let count = all_distances.partition_point(|&d| d <= threshold) as f32; + let cdf = count / n; + table[bucket] = 1.0 - cdf; // High distance = low similarity + } + table[0] = 1.0; // Self-distance = perfect similarity + + table +} + +// ============================================================================ +// Global LazyLock runtimes — loaded once, used forever +// ============================================================================ + +/// Jina v4 runtime (20K tokens). LazyLock: zero cost after first access. +pub static JINA: LazyLock = LazyLock::new(|| { + ModelRuntime::load(ModelSource::JinaV4, JINA_BASE17, JINA_PALETTE) +}); + +/// GPT-2 runtime (50K tokens). Same BPE as Jina → interoperable palettes. +pub static GPT2: LazyLock = LazyLock::new(|| { + ModelRuntime::load(ModelSource::Gpt2, GPT2_BASE17, GPT2_PALETTE) +}); + +/// BERT runtime (30K tokens). WordPiece tokenizer (different from GPT-2 BPE). 
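+/// Its token IDs index the WordPiece vocabulary, so they are not
+/// interchangeable with the BPE-based GPT-2/Jina runtimes.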
+pub static BERT: LazyLock = LazyLock::new(|| { + ModelRuntime::load(ModelSource::Bert, BERT_BASE17, BERT_PALETTE) +}); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_jina_runtime_loads() { + let rt = &*JINA; + assert_eq!(rt.source, ModelSource::JinaV4); + assert_eq!(rt.vocab_size(), 20000); + assert!((rt.similarity[0] - 1.0).abs() < 0.01, "self-similarity should be ~1.0"); + } + + #[test] + fn test_gpt2_runtime_loads() { + let rt = &*GPT2; + assert_eq!(rt.source, ModelSource::Gpt2); + assert_eq!(rt.vocab_size(), 50257); + } + + #[test] + fn test_bert_runtime_loads() { + let rt = &*BERT; + assert_eq!(rt.source, ModelSource::Bert); + assert_eq!(rt.vocab_size(), 30522); + } + + #[test] + fn test_heel_self_distance_zero() { + let rt = &*GPT2; + assert_eq!(rt.heel_distance(0, 0), 0); + assert!((rt.heel_similarity(0, 0) - 1.0).abs() < 0.01); + } + + #[test] + fn test_heel_symmetric() { + let rt = &*GPT2; + assert_eq!(rt.heel_distance(100, 200), rt.heel_distance(200, 100)); + } + + #[test] + fn test_cascade_trivial_same() { + let rt = &*JINA; + let (d, level) = rt.cascade_distance(0, 0); + assert_eq!(d, 0); + assert_eq!(level, HhtlLevel::Heel); + } + + #[test] + fn test_pack_spo_edge() { + let rt = &*GPT2; + let edge = rt.pack_spo_edge(100, 200, 300, 0.8, 0.6, 42); + assert_eq!(causal::edge_temporal(edge), 42); + assert!((causal::edge_freq(edge) - 0.8).abs() < 0.01); + } + + #[test] + fn test_cam_fingerprint() { + let rt = &*BERT; + let fp = rt.cam_fingerprint(1000); + // First byte is palette index + assert_eq!(fp[0], rt.palette.palette_index(1000)); + // Should be 6 bytes + assert_eq!(fp.len(), 6); + } + + #[test] + fn test_similarity_table_monotonic() { + let rt = &*GPT2; + // Similarity should generally decrease with bucket index + // (higher bucket = larger distance = lower similarity) + assert!(rt.similarity[0] >= rt.similarity[255]); + } + + #[test] + fn test_cross_model_palette_comparison() { + // GPT-2 and Jina share BPE — token 0 in both should be + // the same subword. Their palette indices may differ + // (different k-means runs) but the Base17 vectors should correlate. + let jina = &*JINA; + let gpt2 = &*GPT2; + + // Token 0 exists in both + let jina_fp = jina.cam_fingerprint(0); + let gpt2_fp = gpt2.cam_fingerprint(0); + + // They're from different models, so fingerprints may differ. + // But both should be valid 6-byte fingerprints. 
+ assert_eq!(jina_fp.len(), 6); + assert_eq!(gpt2_fp.len(), 6); + } +} diff --git a/src/hpc/jina/weights/bert_base17_30k.bin b/src/hpc/jina/weights/bert_base17_30k.bin new file mode 100644 index 00000000..8e167198 Binary files /dev/null and b/src/hpc/jina/weights/bert_base17_30k.bin differ diff --git a/src/hpc/jina/weights/bert_palette_30k.bin b/src/hpc/jina/weights/bert_palette_30k.bin new file mode 100644 index 00000000..d8c59433 Binary files /dev/null and b/src/hpc/jina/weights/bert_palette_30k.bin differ diff --git a/src/hpc/jina/weights/gpt2_base17_50k.bin b/src/hpc/jina/weights/gpt2_base17_50k.bin new file mode 100644 index 00000000..3954c301 Binary files /dev/null and b/src/hpc/jina/weights/gpt2_base17_50k.bin differ diff --git a/src/hpc/jina/weights/gpt2_palette_50k.bin b/src/hpc/jina/weights/gpt2_palette_50k.bin new file mode 100644 index 00000000..5163e330 Binary files /dev/null and b/src/hpc/jina/weights/gpt2_palette_50k.bin differ diff --git a/src/hpc/mod.rs b/src/hpc/mod.rs index 090fe54e..4c34f4b6 100644 --- a/src/hpc/mod.rs +++ b/src/hpc/mod.rs @@ -170,6 +170,22 @@ pub mod gguf; #[allow(missing_docs)] pub mod jina; +/// Shared model primitives — safetensors, SIMD layers, API types. +#[allow(missing_docs)] +pub mod models; + +/// GPT-2 inference engine — full forward pass + OpenAI-compatible API types. +#[allow(missing_docs)] +pub mod gpt2; + +/// Stable Diffusion inference — CLIP + UNet + VAE + DDIM scheduler. +#[allow(missing_docs)] +pub mod stable_diffusion; + +/// OpenChat 3.5 inference — Mistral-7B architecture (GQA + RoPE + RMSNorm + SiLU). +#[allow(missing_docs)] +pub mod openchat; + // jitson: JSON config → scan pipeline (parser, validator, template, precompile, packed) // Always available — no Cranelift dependency. #[allow(missing_docs)] diff --git a/src/hpc/models/api_types.rs b/src/hpc/models/api_types.rs new file mode 100644 index 00000000..fcb8a72f --- /dev/null +++ b/src/hpc/models/api_types.rs @@ -0,0 +1,122 @@ +//! OpenAI-compatible API types shared across all model endpoints. +//! +//! Transport-agnostic — no HTTP framework dependency. +//! Used by GPT-2 (/v1/completions), Stable Diffusion (/v1/images/generations), +//! BERT/Jina (/v1/embeddings). + +/// Token usage statistics (shared by all endpoints). +#[derive(Clone, Debug, Default)] +pub struct Usage { + pub prompt_tokens: usize, + pub completion_tokens: usize, + pub total_tokens: usize, +} + +/// Why generation stopped. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FinishReason { + /// Hit stop token or stop sequence. + Stop, + /// Hit max_tokens limit. + Length, + /// Content filter triggered. + ContentFilter, +} + +/// Error response envelope. +#[derive(Clone, Debug)] +pub struct ApiError { + pub message: String, + pub error_type: String, + pub code: Option, +} + +impl ApiError { + pub fn invalid_request(msg: impl Into) -> Self { + Self { + message: msg.into(), + error_type: "invalid_request_error".into(), + code: None, + } + } + + pub fn model_not_found(model: &str) -> Self { + Self { + message: format!("model '{}' not found", model), + error_type: "invalid_request_error".into(), + code: Some("model_not_found".into()), + } + } +} + +/// Model info for /v1/models listing. +#[derive(Clone, Debug)] +pub struct ModelCard { + pub id: String, + pub owned_by: String, + pub created: u64, +} + +/// Embedding data for /v1/embeddings response. +#[derive(Clone, Debug)] +pub struct EmbeddingData { + pub index: usize, + pub embedding: Vec, +} + +/// /v1/embeddings response (shared by BERT, Jina, GPT-2 wte). 
+#[derive(Clone, Debug)] +pub struct EmbeddingResponse { + pub model: String, + pub data: Vec, + pub usage: Usage, +} + +/// Image data for /v1/images/generations response. +#[derive(Clone, Debug)] +pub struct ImageData { + /// Base64-encoded PNG, or URL if hosted. + pub b64_json: Option, + pub url: Option, + pub revised_prompt: Option, +} + +/// /v1/images/generations response (Stable Diffusion). +#[derive(Clone, Debug)] +pub struct ImageResponse { + pub created: u64, + pub data: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_usage_default() { + let u = Usage::default(); + assert_eq!(u.prompt_tokens, 0); + assert_eq!(u.total_tokens, 0); + } + + #[test] + fn test_api_error_invalid_request() { + let e = ApiError::invalid_request("bad input"); + assert_eq!(e.error_type, "invalid_request_error"); + assert!(e.code.is_none()); + } + + #[test] + fn test_api_error_model_not_found() { + let e = ApiError::model_not_found("gpt-5"); + assert!(e.message.contains("gpt-5")); + assert_eq!(e.code.as_deref(), Some("model_not_found")); + } + + #[test] + fn test_finish_reason_eq() { + assert_eq!(FinishReason::Stop, FinishReason::Stop); + assert_ne!(FinishReason::Stop, FinishReason::Length); + assert_ne!(FinishReason::Length, FinishReason::ContentFilter); + } +} diff --git a/src/hpc/models/layers.rs b/src/hpc/models/layers.rs new file mode 100644 index 00000000..9a726780 --- /dev/null +++ b/src/hpc/models/layers.rs @@ -0,0 +1,432 @@ +//! Shared SIMD-accelerated neural network layers. +//! +//! All ops use `crate::simd::F32x16` — the ONLY SIMD interface consumers see. +//! Used by GPT-2, Stable Diffusion CLIP, BERT, and any future transformer model. + +use crate::simd::F32x16; + +/// Layer normalization with F32x16 SIMD. +/// +/// `x` is modified in-place: `x = (x - mean) / sqrt(var + eps) * weight + bias` +pub fn layer_norm(x: &mut [f32], weight: &[f32], bias: &[f32]) { + let n = x.len(); + let chunks = n / 16; + + // Mean (SIMD) + let mut sum_acc = F32x16::splat(0.0); + for c in 0..chunks { + let off = c * 16; + sum_acc = sum_acc + F32x16::from_slice(&x[off..off + 16]); + } + let mut mean = sum_acc.reduce_sum(); + for i in (chunks * 16)..n { + mean += x[i]; + } + mean /= n as f32; + + // Variance (SIMD) + let mean_vec = F32x16::splat(mean); + let mut var_acc = F32x16::splat(0.0); + for c in 0..chunks { + let off = c * 16; + let diff = F32x16::from_slice(&x[off..off + 16]) - mean_vec; + var_acc = diff.mul_add(diff, var_acc); + } + let mut var = var_acc.reduce_sum(); + for i in (chunks * 16)..n { + let d = x[i] - mean; + var += d * d; + } + var /= n as f32; + let inv_std = 1.0 / (var + 1e-5).sqrt(); + + // Normalize + scale + shift (SIMD) + let inv_std_vec = F32x16::splat(inv_std); + for c in 0..chunks { + let off = c * 16; + let val = F32x16::from_slice(&x[off..off + 16]); + let w = F32x16::from_slice(&weight[off..off + 16]); + let b = F32x16::from_slice(&bias[off..off + 16]); + let normed = (val - mean_vec) * inv_std_vec; + let result = normed * w + b; + result.copy_to_slice(&mut x[off..off + 16]); + } + for i in (chunks * 16)..n { + x[i] = (x[i] - mean) * inv_std * weight[i] + bias[i]; + } +} + +/// Group normalization with F32x16 SIMD. +/// +/// Used by UNet (Stable Diffusion). Splits channels into `num_groups`, +/// normalizes each group independently. 
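+///
+/// Worked 2-group example (mirrors `test_group_norm_two_groups` below): with
+/// `x = [1, 2, 3, 4]`, unit weight and zero bias, each pair is normalized to
+/// zero mean independently, giving approximately `[-1, 1, -1, 1]`.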
+pub fn group_norm(x: &mut [f32], num_groups: usize, weight: &[f32], bias: &[f32]) { + let total = x.len(); + let group_size = total / num_groups; + + for g in 0..num_groups { + let start = g * group_size; + let end = start + group_size; + let group = &mut x[start..end]; + + // Mean + let mut mean = 0.0f32; + for &v in group.iter() { + mean += v; + } + mean /= group_size as f32; + + // Variance + let mut var = 0.0f32; + for &v in group.iter() { + let d = v - mean; + var += d * d; + } + var /= group_size as f32; + let inv_std = 1.0 / (var + 1e-5).sqrt(); + + // Normalize + affine + for i in 0..group_size { + let idx = start + i; + group[i] = (group[i] - mean) * inv_std * weight[idx] + bias[idx]; + } + } +} + +/// GELU activation: `x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))` +/// +/// SIMD-accelerated via F32x16 exp for tanh approximation. +pub fn gelu(x: &mut [f32]) { + let n = x.len(); + let sqrt_2_over_pi = F32x16::splat(0.7978845608); + let coeff = F32x16::splat(0.044715); + let half = F32x16::splat(0.5); + let one = F32x16::splat(1.0); + + let chunks = n / 16; + for c in 0..chunks { + let off = c * 16; + let v = F32x16::from_slice(&x[off..off + 16]); + let v3 = v * v * v; + let inner = sqrt_2_over_pi * (v + coeff * v3); + let two_inner = inner + inner; + let exp_2x = crate::simd::simd_exp_f32(two_inner); + let tanh_v = (exp_2x - one) / (exp_2x + one); + let result = v * half * (one + tanh_v); + result.copy_to_slice(&mut x[off..off + 16]); + } + for i in (chunks * 16)..n { + let v = x[i]; + let inner = 0.7978845608 * (v + 0.044715 * v * v * v); + x[i] = v * 0.5 * (1.0 + inner.tanh()); + } +} + +/// SiLU (Swish) activation: `x * sigmoid(x)`. +/// +/// Used by Stable Diffusion UNet. SIMD-accelerated via F32x16 exp. +pub fn silu(x: &mut [f32]) { + let n = x.len(); + let one = F32x16::splat(1.0); + + let chunks = n / 16; + for c in 0..chunks { + let off = c * 16; + let v = F32x16::from_slice(&x[off..off + 16]); + let neg_v = F32x16::splat(0.0) - v; + let exp_neg = crate::simd::simd_exp_f32(neg_v); + let sigmoid = one / (one + exp_neg); + let result = v * sigmoid; + result.copy_to_slice(&mut x[off..off + 16]); + } + for i in (chunks * 16)..n { + let v = x[i]; + let sig = 1.0 / (1.0 + (-v).exp()); + x[i] = v * sig; + } +} + +/// Numerically stable softmax (in-place). +pub fn softmax(x: &mut [f32]) { + let mut max_val = f32::NEG_INFINITY; + for &v in x.iter() { + if v > max_val { + max_val = v; + } + } + + let mut sum = 0.0f32; + for v in x.iter_mut() { + *v = (*v - max_val).exp(); + sum += *v; + } + + let inv_sum = 1.0 / sum; + for v in x.iter_mut() { + *v *= inv_sum; + } +} + +/// Matrix-vector multiply: `output = input @ weight + bias`. +/// +/// Weight must be PRE-TRANSPOSED to `[out_dim, in_dim]` for contiguous SIMD. +/// Use `models::safetensors::transpose_matrix()` at load time. 
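+///
+/// Small worked example (weight already `[out_dim, in_dim]`):
+///
+/// ```ignore
+/// let input = [1.0, 2.0, 3.0];
+/// let weight = [1.0, 0.0, 1.0,   // row for output 0
+///               0.0, 2.0, 0.0];  // row for output 1
+/// let bias = [0.5, -0.5];
+/// let mut out = [0.0f32; 2];
+/// matmul_vec(&input, &weight, &bias, &mut out, 3, 2);
+/// // out == [4.5, 3.5]
+/// ```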
+pub fn matmul_vec(input: &[f32], weight: &[f32], bias: &[f32], output: &mut [f32], in_dim: usize, out_dim: usize) { + let chunks = in_dim / 16; + let remainder = in_dim % 16; + + for o in 0..out_dim { + let row_offset = o * in_dim; + let mut acc = F32x16::splat(0.0); + for c in 0..chunks { + let off = c * 16; + let vi = F32x16::from_slice(&input[off..off + 16]); + let vw = F32x16::from_slice(&weight[row_offset + off..row_offset + off + 16]); + acc = vi.mul_add(vw, acc); + } + let mut dot = acc.reduce_sum(); + let tail_start = chunks * 16; + for i in 0..remainder { + dot += input[tail_start + i] * weight[row_offset + tail_start + i]; + } + output[o] = dot + bias[o]; + } +} + +/// SIMD dot product of two f32 slices (same length). +#[inline] +pub fn dot_product(a: &[f32], b: &[f32]) -> f32 { + debug_assert_eq!(a.len(), b.len()); + let n = a.len(); + let chunks = n / 16; + let mut acc = F32x16::splat(0.0); + for c in 0..chunks { + let off = c * 16; + let va = F32x16::from_slice(&a[off..off + 16]); + let vb = F32x16::from_slice(&b[off..off + 16]); + acc = va.mul_add(vb, acc); + } + let mut sum = acc.reduce_sum(); + for i in (chunks * 16)..n { + sum += a[i] * b[i]; + } + sum +} + +/// RMS normalization (Mistral/Llama style): `x = x * weight / sqrt(mean(x²) + eps)` +/// +/// No mean subtraction, no bias. Simpler and faster than LayerNorm. +/// Used by OpenChat 3.5, Mistral, Llama 2/3. +pub fn rms_norm(x: &mut [f32], weight: &[f32], eps: f32) { + let n = x.len(); + let chunks = n / 16; + + // Mean of squares (SIMD) + let mut sq_acc = F32x16::splat(0.0); + for c in 0..chunks { + let off = c * 16; + let v = F32x16::from_slice(&x[off..off + 16]); + sq_acc = v.mul_add(v, sq_acc); + } + let mut mean_sq = sq_acc.reduce_sum(); + for i in (chunks * 16)..n { + mean_sq += x[i] * x[i]; + } + mean_sq /= n as f32; + + let inv_rms = 1.0 / (mean_sq + eps).sqrt(); + let inv_rms_vec = F32x16::splat(inv_rms); + + // Normalize × weight (SIMD) + for c in 0..chunks { + let off = c * 16; + let v = F32x16::from_slice(&x[off..off + 16]); + let w = F32x16::from_slice(&weight[off..off + 16]); + let result = v * inv_rms_vec * w; + result.copy_to_slice(&mut x[off..off + 16]); + } + for i in (chunks * 16)..n { + x[i] = x[i] * inv_rms * weight[i]; + } +} + +/// Apply Rotary Positional Embedding (RoPE) to Q and K vectors. +/// +/// Rotates pairs of dimensions by position-dependent angles: +/// `(q[2i], q[2i+1]) = R(θ_i × pos) × (q[2i], q[2i+1])` +/// where θ_i = 10000^(-2i/d). +/// +/// Used by Mistral, Llama, OpenChat (replaces learned positional embeddings). 
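The reason `matmul_vec` insists on a pre-transposed `[out_dim, in_dim]` weight: every output element then reads one contiguous row, which is what the 16-wide loads assume. A small scalar sketch comparing the strided and contiguous layouts (standalone, not part of the patch):

```rust
fn main() {
    let (in_dim, out_dim) = (3usize, 2usize);
    // Row-major [in_dim, out_dim], as a checkpoint typically stores a linear layer.
    let w_in_out = [1.0f32, 4.0, 2.0, 5.0, 3.0, 6.0];
    // Pre-transposed to [out_dim, in_dim]: one contiguous row per output.
    let w_out_in = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let x = [1.0f32, 10.0, 100.0];

    for o in 0..out_dim {
        // Strided access over the original layout...
        let strided: f32 = (0..in_dim).map(|i| x[i] * w_in_out[i * out_dim + o]).sum();
        // ...equals a contiguous row dot product over the transposed layout.
        let contiguous: f32 = (0..in_dim).map(|i| x[i] * w_out_in[o * in_dim + i]).sum();
        assert!((strided - contiguous).abs() < 1e-6);
    }
}
```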
+pub fn rope_apply(q: &mut [f32], k: &mut [f32], head_dim: usize, position: usize, rope_theta: f32) { + let half = head_dim / 2; + for i in 0..half { + let theta = rope_theta.powf(-(2.0 * i as f32) / head_dim as f32); + let angle = position as f32 * theta; + let cos_a = angle.cos(); + let sin_a = angle.sin(); + + // Apply to Q + let q0 = q[2 * i]; + let q1 = q[2 * i + 1]; + q[2 * i] = q0 * cos_a - q1 * sin_a; + q[2 * i + 1] = q0 * sin_a + q1 * cos_a; + + // Apply to K + let k0 = k[2 * i]; + let k1 = k[2 * i + 1]; + k[2 * i] = k0 * cos_a - k1 * sin_a; + k[2 * i + 1] = k0 * sin_a + k1 * cos_a; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_layer_norm_zero_mean() { + let mut x = vec![1.0, 2.0, 3.0, 4.0]; + let w = vec![1.0; 4]; + let b = vec![0.0; 4]; + layer_norm(&mut x, &w, &b); + let mean: f32 = x.iter().sum::() / 4.0; + assert!(mean.abs() < 0.01); + } + + #[test] + fn test_gelu_zero() { + let mut x = vec![0.0f32; 16]; + gelu(&mut x); + assert!(x[0].abs() < 0.01); + } + + #[test] + fn test_gelu_positive() { + let mut x = vec![2.0f32; 16]; + gelu(&mut x); + assert!((x[0] - 1.9545).abs() < 0.01); + } + + #[test] + fn test_silu_zero() { + let mut x = vec![0.0f32; 16]; + silu(&mut x); + assert!(x[0].abs() < 0.01, "SiLU(0) = 0"); + } + + #[test] + fn test_silu_positive() { + let mut x = vec![2.0f32; 16]; + silu(&mut x); + // SiLU(2) = 2 * sigmoid(2) ≈ 2 * 0.8808 ≈ 1.7616 + assert!((x[0] - 1.7616).abs() < 0.01, "SiLU(2) ≈ 1.76, got {}", x[0]); + } + + #[test] + fn test_softmax_sums_to_one() { + let mut x = vec![1.0, 2.0, 3.0, 4.0]; + softmax(&mut x); + let sum: f32 = x.iter().sum(); + assert!((sum - 1.0).abs() < 1e-5); + } + + #[test] + fn test_group_norm_two_groups() { + let mut x = vec![1.0, 2.0, 3.0, 4.0]; // 2 groups of 2 + let w = vec![1.0; 4]; + let b = vec![0.0; 4]; + group_norm(&mut x, 2, &w, &b); + // Each group normalized independently + let g1_mean = (x[0] + x[1]) / 2.0; + let g2_mean = (x[2] + x[3]) / 2.0; + assert!(g1_mean.abs() < 0.01); + assert!(g2_mean.abs() < 0.01); + } + + #[test] + fn test_dot_product_basic() { + let a = vec![1.0, 2.0, 3.0]; + let b = vec![4.0, 5.0, 6.0]; + let d = dot_product(&a, &b); + assert!((d - 32.0).abs() < 1e-5); // 4+10+18 = 32 + } + + #[test] + fn test_dot_product_simd_path() { + let a: Vec = (0..48).map(|i| i as f32).collect(); + let b: Vec = (0..48).map(|i| (i * 2) as f32).collect(); + let d = dot_product(&a, &b); + let expected: f32 = (0..48).map(|i| (i * i * 2) as f32).sum(); + assert!((d - expected).abs() < 1.0); + } + + #[test] + fn test_matmul_vec_identity() { + // 2×2 identity matrix (pre-transposed = still identity) + let input = vec![3.0, 7.0]; + let weight = vec![1.0, 0.0, 0.0, 1.0]; // [out=2, in=2] + let bias = vec![0.0, 0.0]; + let mut output = vec![0.0; 2]; + matmul_vec(&input, &weight, &bias, &mut output, 2, 2); + assert!((output[0] - 3.0).abs() < 1e-5); + assert!((output[1] - 7.0).abs() < 1e-5); + } + + #[test] + fn test_rms_norm_unit_weight() { + let mut x = vec![3.0, 4.0]; // rms = sqrt((9+16)/2) = sqrt(12.5) ≈ 3.536 + let w = vec![1.0; 2]; + rms_norm(&mut x, &w, 1e-5); + let rms = (12.5f32).sqrt(); + assert!((x[0] - 3.0 / rms).abs() < 0.01); + assert!((x[1] - 4.0 / rms).abs() < 0.01); + } + + #[test] + fn test_rms_norm_scaling() { + let mut x = vec![1.0, 1.0, 1.0, 1.0]; + let w = vec![2.0; 4]; + rms_norm(&mut x, &w, 1e-5); + // rms = 1.0, so result = 1.0 * 2.0 = 2.0 + assert!((x[0] - 2.0).abs() < 0.01); + } + + #[test] + fn test_rope_position_zero_identity() { + let mut q = vec![1.0, 2.0, 3.0, 4.0]; + 
let mut k = vec![5.0, 6.0, 7.0, 8.0]; + let orig_q = q.clone(); + let orig_k = k.clone(); + rope_apply(&mut q, &mut k, 4, 0, 10000.0); + // At position 0, angle = 0, cos=1, sin=0 → identity + for i in 0..4 { + assert!((q[i] - orig_q[i]).abs() < 1e-5); + assert!((k[i] - orig_k[i]).abs() < 1e-5); + } + } + + #[test] + fn test_rope_changes_with_position() { + let mut q1 = vec![1.0, 0.0, 1.0, 0.0]; + let mut k1 = vec![1.0, 0.0, 1.0, 0.0]; + let mut q2 = q1.clone(); + let mut k2 = k1.clone(); + rope_apply(&mut q1, &mut k1, 4, 1, 10000.0); + rope_apply(&mut q2, &mut k2, 4, 100, 10000.0); + // Different positions should give different results + let diff: f32 = q1.iter().zip(&q2).map(|(a, b)| (a - b).abs()).sum(); + assert!(diff > 0.01, "different positions should produce different embeddings"); + } + + #[test] + fn test_rope_preserves_norm() { + let mut q = vec![3.0, 4.0, 1.0, 2.0]; + let mut k = vec![0.0; 4]; + let norm_before: f32 = q.iter().map(|x| x * x).sum::().sqrt(); + rope_apply(&mut q, &mut k, 4, 42, 10000.0); + let norm_after: f32 = q.iter().map(|x| x * x).sum::().sqrt(); + // RoPE is a rotation — should preserve L2 norm + assert!((norm_before - norm_after).abs() < 0.01, + "RoPE should preserve norm: {} vs {}", norm_before, norm_after); + } +} diff --git a/src/hpc/models/mod.rs b/src/hpc/models/mod.rs new file mode 100644 index 00000000..5310f3c9 --- /dev/null +++ b/src/hpc/models/mod.rs @@ -0,0 +1,10 @@ +//! Shared model primitives — used by GPT-2, Stable Diffusion, BERT, Jina. +//! +//! Extracts common patterns so each model crate is thin: +//! - `safetensors`: generic file loader (header parse + tensor extract) +//! - `layers`: SIMD-accelerated ops (LayerNorm, GELU, softmax, matmul) +//! - `api_types`: OpenAI-compatible request/response envelope + +pub mod safetensors; +pub mod layers; +pub mod api_types; diff --git a/src/hpc/models/safetensors.rs b/src/hpc/models/safetensors.rs new file mode 100644 index 00000000..46a30d5f --- /dev/null +++ b/src/hpc/models/safetensors.rs @@ -0,0 +1,205 @@ +//! Generic safetensors file loader. +//! +//! Shared between GPT-2, Stable Diffusion, BERT — any model stored +//! in HuggingFace safetensors format. +//! +//! Format: `[header_size:u64_le][header_json][tensor_data]` + +use std::collections::HashMap; + +/// Tensor metadata from safetensors header. +#[derive(Clone, Debug)] +pub struct TensorMeta { + /// Byte offset into the data section. + pub offset: usize, + /// Byte size of the tensor data. + pub size: usize, +} + +/// Parsed safetensors file: header metadata + raw bytes. +pub struct SafeTensorsFile { + /// Raw file bytes (header + data). + data: Vec, + /// Byte offset where tensor data begins. + data_start: usize, + /// Tensor name → metadata. + pub tensors: HashMap, +} + +impl SafeTensorsFile { + /// Load and parse a safetensors file. + pub fn open(path: &std::path::Path) -> Result { + let data = std::fs::read(path).map_err(|e| format!("read {}: {}", path.display(), e))?; + Self::from_bytes(data) + } + + /// Parse from in-memory bytes (for embedded weights or tests). 
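Before the parsing code below, the byte layout `from_bytes` expects, built by hand as a minimal in-memory example (a sketch for illustration; the JSON and tensor values are made up):

```rust
fn main() {
    // [header_size: u64 LE][header JSON][tensor data]
    let header = br#"{"w":{"dtype":"F32","shape":[2],"data_offsets":[0,8]}}"#;
    let tensor_bytes: Vec<u8> = [1.0f32, 2.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();

    let mut file = Vec::new();
    file.extend_from_slice(&(header.len() as u64).to_le_bytes()); // 8-byte header size
    file.extend_from_slice(header);                               // JSON header
    file.extend_from_slice(&tensor_bytes);                        // data section

    // data_offsets in the header are relative to the start of the data section,
    // i.e. byte 8 + header_size, which is where `data_start` points.
    assert_eq!(file.len(), 8 + header.len() + 8);
}
```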
+ pub fn from_bytes(data: Vec) -> Result { + if data.len() < 8 { + return Err("file too small for safetensors header".into()); + } + + let header_size = u64::from_le_bytes([ + data[0], data[1], data[2], data[3], + data[4], data[5], data[6], data[7], + ]) as usize; + + if 8 + header_size > data.len() { + return Err(format!("header_size {} exceeds file len {}", header_size, data.len())); + } + + let header_json = std::str::from_utf8(&data[8..8 + header_size]) + .map_err(|e| format!("invalid UTF-8 in header: {}", e))?; + + let data_start = 8 + header_size; + let tensors = parse_header(header_json)?; + + Ok(Self { data, data_start, tensors }) + } + + /// Read a tensor as Vec (little-endian F32). + pub fn read_f32(&self, name: &str) -> Result, String> { + let meta = self.tensors.get(name) + .ok_or_else(|| format!("missing tensor: {}", name))?; + let start = self.data_start + meta.offset; + let end = start + meta.size; + if end > self.data.len() { + return Err(format!("tensor {} [{}, {}) exceeds file len {}", name, start, end, self.data.len())); + } + Ok(self.data[start..end] + .chunks_exact(4) + .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) + .collect()) + } + + /// Read a tensor as Vec stored as raw u16 (for F16 tensors). + pub fn read_f16_raw(&self, name: &str) -> Result, String> { + let meta = self.tensors.get(name) + .ok_or_else(|| format!("missing tensor: {}", name))?; + let start = self.data_start + meta.offset; + let end = start + meta.size; + if end > self.data.len() { + return Err(format!("tensor {} exceeds file", name)); + } + Ok(self.data[start..end] + .chunks_exact(2) + .map(|c| u16::from_le_bytes([c[0], c[1]])) + .collect()) + } + + /// Check if a tensor exists. + pub fn has_tensor(&self, name: &str) -> bool { + self.tensors.contains_key(name) + } + + /// List all tensor names. + pub fn tensor_names(&self) -> Vec<&str> { + self.tensors.keys().map(|s| s.as_str()).collect() + } + + /// Total data size in bytes. + pub fn data_size(&self) -> usize { + self.data.len() - self.data_start + } +} + +/// Transpose a [rows, cols] row-major matrix to [cols, rows]. +/// Used by all models to pre-transpose weights for SIMD-contiguous matmul. +pub fn transpose_matrix(data: &mut Vec, rows: usize, cols: usize) { + assert_eq!(data.len(), rows * cols); + let mut transposed = vec![0.0f32; rows * cols]; + for r in 0..rows { + for c in 0..cols { + transposed[c * rows + r] = data[r * cols + c]; + } + } + *data = transposed; +} + +/// Parse safetensors JSON header to tensor metadata. 
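For reference, the little-endian decode `read_f32` performs, as a standalone round trip (sketch only):

```rust
fn main() {
    let original = vec![3.5f32, -1.25, 0.0];
    let bytes: Vec<u8> = original.iter().flat_map(|v| v.to_le_bytes()).collect();

    // chunks_exact(4) + from_le_bytes, the same decode used by read_f32.
    let decoded: Vec<f32> = bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect();

    assert_eq!(decoded, original);
}
```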
+fn parse_header(json: &str) -> Result, String> { + let mut tensors = HashMap::new(); + let mut pos = 0; + + while let Some(key_start) = json[pos..].find('"') { + let key_start = pos + key_start + 1; + let key_end = match json[key_start..].find('"') { + Some(e) => key_start + e, + None => break, + }; + let key = &json[key_start..key_end]; + pos = key_end + 1; + + // Skip __metadata__ + if key == "__metadata__" { + if let Some(end) = json[pos..].find('}') { + pos += end + 1; + } + continue; + } + + // Find data_offsets + if let Some(offsets_start) = json[pos..].find("data_offsets") { + let search_start = pos + offsets_start; + if let Some(bracket_start) = json[search_start..].find('[') { + let arr_start = search_start + bracket_start + 1; + if let Some(bracket_end) = json[arr_start..].find(']') { + let arr = &json[arr_start..arr_start + bracket_end]; + let nums: Vec = arr.split(',') + .filter_map(|s| s.trim().parse().ok()) + .collect(); + if nums.len() == 2 { + tensors.insert(key.to_string(), TensorMeta { + offset: nums[0], + size: nums[1] - nums[0], + }); + } + } + } + } + + if let Some(brace) = json[pos..].find('}') { + pos += brace + 1; + } + } + + Ok(tensors) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_header_basic() { + let json = r#"{"tensor_a": {"dtype": "F32", "shape": [4], "data_offsets": [0, 16]}, "tensor_b": {"dtype": "F32", "shape": [2], "data_offsets": [16, 24]}}"#; + let tensors = parse_header(json).unwrap(); + assert_eq!(tensors.len(), 2); + assert_eq!(tensors["tensor_a"].offset, 0); + assert_eq!(tensors["tensor_a"].size, 16); + assert_eq!(tensors["tensor_b"].offset, 16); + assert_eq!(tensors["tensor_b"].size, 8); + } + + #[test] + fn test_parse_header_with_metadata() { + let json = r#"{"__metadata__": {"format": "pt"}, "w": {"dtype": "F32", "shape": [3], "data_offsets": [0, 12]}}"#; + let tensors = parse_header(json).unwrap(); + assert_eq!(tensors.len(), 1); + assert!(tensors.contains_key("w")); + } + + #[test] + fn test_transpose_matrix() { + let mut m = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2×3 + transpose_matrix(&mut m, 2, 3); + // Expected 3×2: [1,4,2,5,3,6] + assert_eq!(m, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]); + } + + #[test] + fn test_safetensors_from_bytes_too_small() { + let err = SafeTensorsFile::from_bytes(vec![0; 4]); + assert!(err.is_err()); + } +} diff --git a/src/hpc/openchat/api.rs b/src/hpc/openchat/api.rs new file mode 100644 index 00000000..1db5d59f --- /dev/null +++ b/src/hpc/openchat/api.rs @@ -0,0 +1,228 @@ +//! OpenAI-compatible chat completions API for OpenChat 3.5. +//! +//! Implements `/v1/chat/completions` with the OpenChat template: +//! ```text +//! GPT4 Correct User: {message}<|end_of_turn|> +//! GPT4 Correct Assistant: +//! ``` + +use crate::hpc::models::api_types::{Usage, FinishReason}; +use super::inference::{GeneratedToken, OpenChatEngine}; +use super::weights::*; + +/// A chat message (role + content). +#[derive(Clone, Debug)] +pub struct ChatMessage { + pub role: ChatRole, + pub content: String, +} + +/// Chat role. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ChatRole { + System, + User, + Assistant, +} + +/// Request body for /v1/chat/completions. +#[derive(Clone, Debug)] +pub struct ChatCompletionRequest { + pub model: String, + pub messages: Vec, + /// Pre-tokenized prompt (built from messages via chat template). 
+ pub prompt_tokens: Vec, + pub max_tokens: usize, + pub temperature: f32, + pub stream: bool, +} + +impl Default for ChatCompletionRequest { + fn default() -> Self { + Self { + model: "openchat_3.5".into(), + messages: Vec::new(), + prompt_tokens: Vec::new(), + max_tokens: 512, + temperature: 0.7, + stream: false, + } + } +} + +/// A chat completion choice. +#[derive(Clone, Debug)] +pub struct ChatChoice { + pub index: usize, + pub message: ChatMessage, + pub finish_reason: FinishReason, +} + +/// Response body for /v1/chat/completions. +#[derive(Clone, Debug)] +pub struct ChatCompletionResponse { + pub id: String, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +/// OpenChat API wrapper. +pub struct OpenChatApi { + engine: OpenChatEngine, + request_counter: u64, +} + +impl OpenChatApi { + pub fn new(weights: OpenChatWeights) -> Self { + Self { + engine: OpenChatEngine::new(weights), + request_counter: 0, + } + } + + /// /v1/chat/completions handler. + pub fn chat_complete(&mut self, req: &ChatCompletionRequest) -> ChatCompletionResponse { + self.request_counter += 1; + + let generated = self.engine.generate( + &req.prompt_tokens, + req.max_tokens, + req.temperature, + ); + + let finish_reason = if generated.len() < req.max_tokens { + FinishReason::Stop + } else { + FinishReason::Length + }; + + let completion_tokens = generated.len(); + let prompt_tokens = req.prompt_tokens.len(); + + ChatCompletionResponse { + id: format!("chatcmpl-{}", self.request_counter), + model: "openchat_3.5".into(), + choices: vec![ChatChoice { + index: 0, + message: ChatMessage { + role: ChatRole::Assistant, + content: format!("[{} tokens generated]", completion_tokens), + }, + finish_reason, + }], + usage: Usage { + prompt_tokens, + completion_tokens, + total_tokens: prompt_tokens + completion_tokens, + }, + } + } + + /// Build prompt token sequence from chat messages using OpenChat template. + /// + /// Format: + /// ```text + /// GPT4 Correct User: {user_msg}<|end_of_turn|>GPT4 Correct Assistant: + /// ``` + /// + /// Returns a description of the template (actual tokenization requires SentencePiece). + pub fn format_chat_template(messages: &[ChatMessage]) -> String { + let mut prompt = String::new(); + for msg in messages { + match msg.role { + ChatRole::System => { + prompt.push_str(&msg.content); + prompt.push('\n'); + } + ChatRole::User => { + prompt.push_str(chat_template::USER_PREFIX); + prompt.push_str(&msg.content); + prompt.push_str(chat_template::EOT_TOKEN); + } + ChatRole::Assistant => { + prompt.push_str(chat_template::ASSISTANT_PREFIX); + prompt.push(' '); + prompt.push_str(&msg.content); + prompt.push_str(chat_template::EOT_TOKEN); + } + } + } + // Always end with assistant prefix to prompt generation + prompt.push_str(chat_template::ASSISTANT_PREFIX); + prompt + } + + /// Access engine for direct manipulation. + pub fn engine_mut(&mut self) -> &mut OpenChatEngine { + &mut self.engine + } + + /// Model info. 
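Spelled out with the literal marker strings from `chat_template` in weights.rs, this is the prompt `format_chat_template` builds for a single user turn (standalone sketch, not part of the patch):

```rust
fn main() {
    let user_prefix = "GPT4 Correct User: ";
    let assistant_prefix = "GPT4 Correct Assistant:";
    let eot = "<|end_of_turn|>";

    let user_msg = "Hello!";
    let prompt = format!("{user_prefix}{user_msg}{eot}{assistant_prefix}");

    assert_eq!(
        prompt,
        "GPT4 Correct User: Hello!<|end_of_turn|>GPT4 Correct Assistant:"
    );
    // The trailing assistant prefix is what cues the model to start generating.
}
```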
+ pub fn model_info() -> crate::hpc::models::api_types::ModelCard { + crate::hpc::models::api_types::ModelCard { + id: "openchat_3.5".into(), + owned_by: "openchat".into(), + created: 0, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_chat_template_format() { + let messages = vec![ + ChatMessage { role: ChatRole::User, content: "Hello!".into() }, + ]; + let prompt = OpenChatApi::format_chat_template(&messages); + assert!(prompt.contains("GPT4 Correct User: Hello!")); + assert!(prompt.contains("<|end_of_turn|>")); + assert!(prompt.ends_with("GPT4 Correct Assistant:")); + } + + #[test] + fn test_chat_template_multi_turn() { + let messages = vec![ + ChatMessage { role: ChatRole::User, content: "Hi".into() }, + ChatMessage { role: ChatRole::Assistant, content: "Hello!".into() }, + ChatMessage { role: ChatRole::User, content: "How are you?".into() }, + ]; + let prompt = OpenChatApi::format_chat_template(&messages); + // Should have two user turns and one assistant turn + assert_eq!(prompt.matches("GPT4 Correct User:").count(), 2); + assert!(prompt.contains("Hello!")); + } + + #[test] + fn test_chat_template_with_system() { + let messages = vec![ + ChatMessage { role: ChatRole::System, content: "You are helpful.".into() }, + ChatMessage { role: ChatRole::User, content: "Hi".into() }, + ]; + let prompt = OpenChatApi::format_chat_template(&messages); + assert!(prompt.starts_with("You are helpful.")); + } + + #[test] + fn test_default_request() { + let req = ChatCompletionRequest::default(); + assert_eq!(req.model, "openchat_3.5"); + assert_eq!(req.max_tokens, 512); + assert!(!req.stream); + } + + #[test] + fn test_model_info() { + let info = OpenChatApi::model_info(); + assert_eq!(info.id, "openchat_3.5"); + } + + #[test] + fn test_chat_role_eq() { + assert_eq!(ChatRole::User, ChatRole::User); + assert_ne!(ChatRole::User, ChatRole::Assistant); + } +} diff --git a/src/hpc/openchat/inference.rs b/src/hpc/openchat/inference.rs new file mode 100644 index 00000000..94b5dc0a --- /dev/null +++ b/src/hpc/openchat/inference.rs @@ -0,0 +1,430 @@ +//! OpenChat 3.5 / Mistral-7B forward pass + generation loop. +//! +//! Key differences from GPT-2: +//! - Grouped Query Attention (GQA): 32 Q heads share 8 KV heads +//! - RoPE positional encoding (no learned position embeddings) +//! - RMSNorm instead of LayerNorm +//! - SiLU activation in FFN (gated MLP: gate * up, then down) +//! - All ops via `crate::hpc::models::layers` (shared F32x16 SIMD) + +use super::weights::*; +use crate::hpc::jina::{causal, runtime}; +use crate::hpc::models::layers; +use crate::simd::F32x16; + +/// A generated token with its probability. +#[derive(Clone, Debug)] +pub struct GeneratedToken { + pub token_id: u32, + pub logprob: f32, +} + +/// CausalEdge64 emitted during attention. +#[derive(Clone, Debug)] +pub struct AttentionEdge { + pub layer: u8, + pub head: u8, + pub edge: u64, +} + +/// KV cache for one layer (GQA: only 8 KV heads cached, not 32). +#[derive(Clone)] +struct KvCache { + /// Cached keys: [seq_len, kv_dim] where kv_dim = 8 × 128 = 1024. + keys: Vec, + /// Cached values: [seq_len, kv_dim]. + values: Vec, +} + +/// OpenChat 3.5 inference engine. +pub struct OpenChatEngine { + weights: OpenChatWeights, + kv_cache: Vec, + seq_len: usize, + token_history: Vec, + /// Accumulated causal edges from attention patterns. + pub causal_edges: Vec, + /// Whether to emit CausalEdge64 from attention patterns. 
+ pub emit_causal_edges: bool, +} + +impl OpenChatEngine { + pub fn new(weights: OpenChatWeights) -> Self { + let n_layers = weights.layers.len(); + let kv_cache = (0..n_layers) + .map(|_| KvCache { + keys: Vec::with_capacity(MAX_SEQ_LEN * KV_DIM), + values: Vec::with_capacity(MAX_SEQ_LEN * KV_DIM), + }) + .collect(); + Self { + weights, + kv_cache, + seq_len: 0, + token_history: Vec::with_capacity(MAX_SEQ_LEN), + causal_edges: Vec::new(), + emit_causal_edges: false, + } + } + + /// Access weights. + pub fn weights(&self) -> &OpenChatWeights { + &self.weights + } + + /// Reset KV cache. + pub fn reset(&mut self) { + for kv in &mut self.kv_cache { + kv.keys.clear(); + kv.values.clear(); + } + self.seq_len = 0; + self.token_history.clear(); + self.causal_edges.clear(); + } + + /// Forward pass for one token → logits over vocabulary. + pub fn forward(&mut self, token_id: u32) -> Vec { + let pos = self.seq_len; + assert!(pos < MAX_SEQ_LEN, "sequence too long ({} >= {})", pos, MAX_SEQ_LEN); + self.token_history.push(token_id); + + // Token embedding (no position embedding — RoPE handles that) + let mut hidden = vec![0.0f32; EMBED_DIM]; + let emb_offset = token_id as usize * EMBED_DIM; + hidden.copy_from_slice(&self.weights.token_embd[emb_offset..emb_offset + EMBED_DIM]); + + // Transformer layers (may be fewer than NUM_LAYERS for testing/distilled models) + let n_layers = self.weights.layers.len(); + for layer_idx in 0..n_layers { + hidden = self.transformer_layer(layer_idx, &hidden, pos); + } + + // Final RMSNorm + layers::rms_norm(&mut hidden, &self.weights.output_norm, RMS_EPS); + + // Logits: hidden @ output.T (weight already in [vocab, embed] layout) + let mut logits = vec![0.0f32; VOCAB_SIZE]; + let chunks = EMBED_DIM / 16; + for v in 0..VOCAB_SIZE { + let w_off = v * EMBED_DIM; + let mut acc = F32x16::splat(0.0); + for c in 0..chunks { + let off = c * 16; + let vh = F32x16::from_slice(&hidden[off..off + 16]); + let vw = F32x16::from_slice(&self.weights.output[w_off + off..w_off + off + 16]); + acc = vh.mul_add(vw, acc); + } + logits[v] = acc.reduce_sum(); + } + + self.seq_len += 1; + logits + } + + /// One transformer layer: GQA attention + gated MLP. + fn transformer_layer(&mut self, layer_idx: usize, input: &[f32], pos: usize) -> Vec { + // Clone norm weights before mutable borrow + let attn_norm = self.weights.layers[layer_idx].attn_norm.clone(); + let ffn_norm = self.weights.layers[layer_idx].ffn_norm.clone(); + + // Pre-attention RMSNorm + let mut normed = input.to_vec(); + layers::rms_norm(&mut normed, &attn_norm, RMS_EPS); + + // GQA attention + let attn_out = self.gqa_attention(layer_idx, &normed, pos); + + // Residual + let mut hidden: Vec = input.iter().zip(&attn_out).map(|(a, b)| a + b).collect(); + + // Pre-FFN RMSNorm + let mut normed2 = hidden.clone(); + layers::rms_norm(&mut normed2, &ffn_norm, RMS_EPS); + + // Gated MLP: SiLU(gate(x)) * up(x) → down + let ffn_out = self.gated_mlp(layer_idx, &normed2); + + // Residual + for i in 0..EMBED_DIM { + hidden[i] += ffn_out[i]; + } + + hidden + } + + /// Grouped Query Attention: 32 Q heads, 8 KV heads (4:1 ratio). + /// + /// Each KV head is shared by 4 Q heads. RoPE applied to Q and K. 
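GQA bookkeeping in `gqa_attention` below: Q head `qh` reads KV head `qh / GQA_RATIO`, so 32 query heads share 8 cached K/V heads. One way to organize the RoPE pass is to rotate each of the 32 Q heads and each of the 8 K heads exactly once, as in this standalone sketch (constants copied from weights.rs; not the engine's exact code path):

```rust
fn rotate_head(x: &mut [f32], pos: usize, head_dim: usize, theta_base: f32) {
    // Same pairwise rotation as layers::rope_apply, for a single head.
    for i in 0..head_dim / 2 {
        let theta = theta_base.powf(-(2.0 * i as f32) / head_dim as f32);
        let (sin_a, cos_a) = (pos as f32 * theta).sin_cos();
        let (x0, x1) = (x[2 * i], x[2 * i + 1]);
        x[2 * i] = x0 * cos_a - x1 * sin_a;
        x[2 * i + 1] = x0 * sin_a + x1 * cos_a;
    }
}

fn main() {
    const NUM_Q_HEADS: usize = 32;
    const NUM_KV_HEADS: usize = 8;
    const HEAD_DIM: usize = 128;
    const GQA_RATIO: usize = NUM_Q_HEADS / NUM_KV_HEADS; // 4

    let mut q = vec![0.1f32; NUM_Q_HEADS * HEAD_DIM];
    let mut k = vec![0.1f32; NUM_KV_HEADS * HEAD_DIM];
    let pos = 7;

    for qh in 0..NUM_Q_HEADS {
        let off = qh * HEAD_DIM;
        rotate_head(&mut q[off..off + HEAD_DIM], pos, HEAD_DIM, 10000.0);
        let _kv_head_for_this_q = qh / GQA_RATIO; // which cached K/V rows qh scores against
    }
    for kh in 0..NUM_KV_HEADS {
        let off = kh * HEAD_DIM;
        // Each K head is rotated once, no matter how many Q heads share it.
        rotate_head(&mut k[off..off + HEAD_DIM], pos, HEAD_DIM, 10000.0);
    }
}
```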
+ fn gqa_attention(&mut self, layer_idx: usize, input: &[f32], pos: usize) -> Vec { + let layer = &self.weights.layers[layer_idx]; + let zero_bias_q = vec![0.0f32; EMBED_DIM]; + let zero_bias_kv = vec![0.0f32; KV_DIM]; + + // Q projection: [4096] → [4096] (32 heads × 128D) + let mut q = vec![0.0f32; EMBED_DIM]; + layers::matmul_vec(input, &layer.attn_q, &zero_bias_q, &mut q, EMBED_DIM, EMBED_DIM); + + // K projection: [4096] → [1024] (8 heads × 128D) + let mut k = vec![0.0f32; KV_DIM]; + layers::matmul_vec(input, &layer.attn_k, &zero_bias_kv, &mut k, EMBED_DIM, KV_DIM); + + // V projection: [4096] → [1024] (8 heads × 128D) + let mut v = vec![0.0f32; KV_DIM]; + layers::matmul_vec(input, &layer.attn_v, &zero_bias_kv, &mut v, EMBED_DIM, KV_DIM); + + // Apply RoPE to Q and K (per-head) + for qh in 0..NUM_Q_HEADS { + let kv_h = qh / GQA_RATIO; + let q_off = qh * HEAD_DIM; + let k_off = kv_h * HEAD_DIM; + layers::rope_apply( + &mut q[q_off..q_off + HEAD_DIM], + &mut k[k_off..k_off + HEAD_DIM], + HEAD_DIM, + pos, + ROPE_THETA, + ); + } + + // Append K, V to cache + self.kv_cache[layer_idx].keys.extend_from_slice(&k); + self.kv_cache[layer_idx].values.extend_from_slice(&v); + + let seq_len = self.seq_len + 1; + let scale = 1.0 / (HEAD_DIM as f32).sqrt(); + + // Per Q-head attention with GQA (4 Q heads share 1 KV head) + let mut output = vec![0.0f32; EMBED_DIM]; + let emit = self.emit_causal_edges; + + for qh in 0..NUM_Q_HEADS { + let kv_h = qh / GQA_RATIO; + let q_off = qh * HEAD_DIM; + + // Scores: Q[qh] · K[kv_h]^T for all cached positions + let mut scores = vec![0.0f32; seq_len]; + for t in 0..seq_len { + let k_off = t * KV_DIM + kv_h * HEAD_DIM; + let mut dot = 0.0f32; + for d in 0..HEAD_DIM { + dot += q[q_off + d] * self.kv_cache[layer_idx].keys[k_off + d]; + } + scores[t] = dot * scale; + } + + layers::softmax(&mut scores); + + // CausalEdge64 emission + if emit { + let current_token = *self.token_history.last().unwrap_or(&0); + for t in 0..seq_len { + if scores[t] > 0.05 && t < self.token_history.len() { + let key_token = self.token_history[t]; + // Use GPT2 palette for now — OpenChat palette can be built later + // Token IDs may differ but the edge structure is the same + let edge = causal::pack_edge( + (current_token % 256) as u8, + (qh % 256) as u8, + (key_token % 256) as u8, + scores[t], + 0.3, + 0b111, // full SPO Pearl mask + self.seq_len as u16, + ); + self.causal_edges.push(AttentionEdge { + layer: layer_idx as u8, + head: qh as u8, + edge, + }); + } + } + } + + // Weighted sum of V[kv_h] + for t in 0..seq_len { + let v_off = t * KV_DIM + kv_h * HEAD_DIM; + let w = scores[t]; + for d in 0..HEAD_DIM { + output[q_off + d] += w * self.kv_cache[layer_idx].values[v_off + d]; + } + } + } + + // Output projection + let zero_bias = vec![0.0f32; EMBED_DIM]; + let mut projected = vec![0.0f32; EMBED_DIM]; + layers::matmul_vec(&output, &self.weights.layers[layer_idx].attn_output, &zero_bias, &mut projected, EMBED_DIM, EMBED_DIM); + + projected + } + + /// Gated MLP: gate(x) = SiLU(W_gate @ x), up(x) = W_up @ x. 
+ /// output = W_down @ (gate(x) * up(x)) + fn gated_mlp(&self, layer_idx: usize, input: &[f32]) -> Vec { + let layer = &self.weights.layers[layer_idx]; + let zero_bias_mlp = vec![0.0f32; MLP_DIM]; + let zero_bias_out = vec![0.0f32; EMBED_DIM]; + + // Gate projection + SiLU + let mut gate = vec![0.0f32; MLP_DIM]; + layers::matmul_vec(input, &layer.ffn_gate, &zero_bias_mlp, &mut gate, EMBED_DIM, MLP_DIM); + layers::silu(&mut gate); + + // Up projection + let mut up = vec![0.0f32; MLP_DIM]; + layers::matmul_vec(input, &layer.ffn_up, &zero_bias_mlp, &mut up, EMBED_DIM, MLP_DIM); + + // Element-wise gate * up + let chunks = MLP_DIM / 16; + for c in 0..chunks { + let off = c * 16; + let vg = F32x16::from_slice(&gate[off..off + 16]); + let vu = F32x16::from_slice(&up[off..off + 16]); + let result = vg * vu; + result.copy_to_slice(&mut gate[off..off + 16]); + } + for i in (chunks * 16)..MLP_DIM { + gate[i] *= up[i]; + } + + // Down projection + let mut output = vec![0.0f32; EMBED_DIM]; + layers::matmul_vec(&gate, &layer.ffn_down, &zero_bias_out, &mut output, MLP_DIM, EMBED_DIM); + + output + } + + /// Generate tokens autoregressively. + pub fn generate( + &mut self, + prompt_tokens: &[u32], + max_new_tokens: usize, + temperature: f32, + ) -> Vec { + self.reset(); + let mut generated = Vec::new(); + + // Process prompt (fill KV cache) + let mut last_logits = vec![0.0f32; VOCAB_SIZE]; + for &token in prompt_tokens { + last_logits = self.forward(token); + } + + // Generate + for _ in 0..max_new_tokens { + if temperature != 1.0 && temperature > 0.0 { + let inv_temp = 1.0 / temperature; + for l in &mut last_logits { + *l *= inv_temp; + } + } + + // Greedy argmax + let mut best_id = 0u32; + let mut best_logit = f32::NEG_INFINITY; + for (i, &l) in last_logits.iter().enumerate() { + if l > best_logit { + best_logit = l; + best_id = i as u32; + } + } + + // EOS + if best_id == chat_template::EOS_TOKEN_ID { + break; + } + + generated.push(GeneratedToken { + token_id: best_id, + logprob: best_logit, + }); + + last_logits = self.forward(best_id); + } + + generated + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn dummy_weights() -> OpenChatWeights { + let layer = MistralLayerWeights { + attn_norm: vec![1.0; EMBED_DIM], + attn_q: vec![0.0; EMBED_DIM * EMBED_DIM], + attn_k: vec![0.0; KV_DIM * EMBED_DIM], + attn_v: vec![0.0; KV_DIM * EMBED_DIM], + attn_output: vec![0.0; EMBED_DIM * EMBED_DIM], + ffn_norm: vec![1.0; EMBED_DIM], + ffn_gate: vec![0.0; MLP_DIM * EMBED_DIM], + ffn_up: vec![0.0; MLP_DIM * EMBED_DIM], + ffn_down: vec![0.0; EMBED_DIM * MLP_DIM], + }; + OpenChatWeights { + token_embd: vec![0.01; VOCAB_SIZE * EMBED_DIM], + layers: vec![layer; 1], // 1 layer for testing (32 would OOM in tests) + output_norm: vec![1.0; EMBED_DIM], + output: vec![0.01; VOCAB_SIZE * EMBED_DIM], + } + } + + #[test] + fn test_engine_creation() { + let w = dummy_weights(); + let engine = OpenChatEngine::new(w); + assert_eq!(engine.seq_len, 0); + assert!(!engine.emit_causal_edges); + } + + #[test] + fn test_engine_reset() { + let w = dummy_weights(); + let mut engine = OpenChatEngine::new(w); + engine.seq_len = 5; + engine.token_history.push(42); + engine.reset(); + assert_eq!(engine.seq_len, 0); + assert!(engine.token_history.is_empty()); + } + + #[test] + fn test_gqa_ratio() { + assert_eq!(GQA_RATIO, 4, "32Q / 8KV = 4:1 sharing"); + } + + #[test] + fn test_forward_produces_logits() { + let w = dummy_weights(); + let mut engine = OpenChatEngine::new(w); + let logits = engine.forward(0); + assert_eq!(logits.len(), 
VOCAB_SIZE); + // With near-zero weights, logits should be near-zero + let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + assert!(max_logit.is_finite(), "logits should be finite, got {}", max_logit); + } + + #[test] + fn test_forward_increments_seq_len() { + let w = dummy_weights(); + let mut engine = OpenChatEngine::new(w); + engine.forward(0); + assert_eq!(engine.seq_len, 1); + engine.forward(1); + assert_eq!(engine.seq_len, 2); + } + + #[test] + fn test_kv_cache_grows() { + let w = dummy_weights(); + let mut engine = OpenChatEngine::new(w); + engine.forward(0); + // KV cache should have 1 entry per layer + assert_eq!(engine.kv_cache[0].keys.len(), KV_DIM); + assert_eq!(engine.kv_cache[0].values.len(), KV_DIM); + engine.forward(1); + assert_eq!(engine.kv_cache[0].keys.len(), 2 * KV_DIM); + } +} diff --git a/src/hpc/openchat/mod.rs b/src/hpc/openchat/mod.rs new file mode 100644 index 00000000..d3433b88 --- /dev/null +++ b/src/hpc/openchat/mod.rs @@ -0,0 +1,39 @@ +//! OpenChat 3.5 inference engine — Mistral-7B architecture on CPU. +//! +//! OpenChat 3.5 is a Mistral-7B fine-tune with ChatGPT-3.5-level performance. +//! Same architecture as Mistral: GQA + RoPE + RMSNorm + SiLU. +//! +//! # Architecture differences from GPT-2 +//! +//! ```text +//! GPT-2: LayerNorm → MHA(12 heads) → GELU → 768D +//! OpenChat: RMSNorm → GQA(32Q/8KV) → SiLU → 4096D +//! ``` +//! +//! | Feature | GPT-2 | OpenChat/Mistral-7B | +//! |-----------------|-----------------|---------------------| +//! | Params | 124M | 7B | +//! | Embed dim | 768 | 4096 | +//! | Layers | 12 | 32 | +//! | Q heads | 12 (MHA) | 32 (GQA) | +//! | KV heads | 12 | 8 (4:1 ratio) | +//! | Activation | GELU | SiLU | +//! | Positional | Learned | RoPE (θ=10000) | +//! | Norm | Pre-LayerNorm | RMSNorm | +//! | Vocab | 50,257 (BPE) | 32,000 (SPM) | +//! | Weight format | Safetensors | GGUF (Q4_K_M) | +//! +//! # Integration +//! +//! All ops via `crate::hpc::models::layers` (shared F32x16 SIMD): +//! - `rms_norm()` — RMSNorm +//! - `rope_apply()` — Rotary Positional Embedding +//! - `silu()` — SiLU/Swish activation +//! - `softmax()`, `matmul_vec()`, `dot_product()` — standard +//! +//! Weight loading via `crate::hpc::gguf` (Q4_K_M dequantization). +//! Codec via `crate::hpc::jina::runtime` (HHTL/CausalEdge64 when available). + +pub mod weights; +pub mod inference; +pub mod api; diff --git a/src/hpc/openchat/weights.rs b/src/hpc/openchat/weights.rs new file mode 100644 index 00000000..1fac817c --- /dev/null +++ b/src/hpc/openchat/weights.rs @@ -0,0 +1,183 @@ +//! OpenChat 3.5 / Mistral-7B weight loading from GGUF format. +//! +//! Uses `crate::hpc::gguf` for dequantization (Q4_K_M, Q8_0, F16). +//! No weights stored in the binary — loaded at runtime from user-provided GGUF. +//! +//! GGUF tensor naming convention (llama.cpp style): +//! ```text +//! token_embd.weight → [32000, 4096] +//! blk.{i}.attn_q.weight → [4096, 4096] +//! blk.{i}.attn_k.weight → [1024, 4096] (GQA: 8 KV heads × 128D) +//! blk.{i}.attn_v.weight → [1024, 4096] +//! blk.{i}.attn_output.weight → [4096, 4096] +//! blk.{i}.attn_norm.weight → [4096] (RMSNorm, no bias) +//! blk.{i}.ffn_gate.weight → [14336, 4096] (SiLU gate) +//! blk.{i}.ffn_up.weight → [14336, 4096] (up projection) +//! blk.{i}.ffn_down.weight → [4096, 14336] (down projection) +//! blk.{i}.ffn_norm.weight → [4096] (RMSNorm) +//! output_norm.weight → [4096] +//! output.weight → [32000, 4096] (or tied to token_embd) +//! 
``` + +use crate::hpc::gguf::{self, GgufFile}; +use crate::hpc::models::safetensors::transpose_matrix; + +/// Mistral-7B / OpenChat 3.5 configuration. +pub const VOCAB_SIZE: usize = 32000; +pub const EMBED_DIM: usize = 4096; +pub const NUM_LAYERS: usize = 32; +pub const NUM_Q_HEADS: usize = 32; +pub const NUM_KV_HEADS: usize = 8; +pub const HEAD_DIM: usize = EMBED_DIM / NUM_Q_HEADS; // 128 +pub const KV_DIM: usize = NUM_KV_HEADS * HEAD_DIM; // 1024 +pub const MLP_DIM: usize = 14336; // Mistral uses 14336 (not 4× embed) +pub const MAX_SEQ_LEN: usize = 8192; // Mistral supports 8K context (32K with sliding window) +pub const ROPE_THETA: f32 = 10000.0; +pub const RMS_EPS: f32 = 1e-5; +pub const GQA_RATIO: usize = NUM_Q_HEADS / NUM_KV_HEADS; // 4 + +/// Weights for one Mistral transformer layer. +#[derive(Clone)] +pub struct MistralLayerWeights { + /// Attention RMSNorm weight [4096] (no bias). + pub attn_norm: Vec, + /// Q projection: [4096, 4096] → pre-transposed to [4096, 4096]. + pub attn_q: Vec, + /// K projection: [1024, 4096] → pre-transposed to [1024, 4096]. + pub attn_k: Vec, + /// V projection: [1024, 4096] → pre-transposed to [1024, 4096]. + pub attn_v: Vec, + /// Output projection: [4096, 4096] → pre-transposed. + pub attn_output: Vec, + /// FFN RMSNorm weight [4096]. + pub ffn_norm: Vec, + /// Gate projection (SiLU): [14336, 4096] → pre-transposed. + pub ffn_gate: Vec, + /// Up projection: [14336, 4096] → pre-transposed. + pub ffn_up: Vec, + /// Down projection: [4096, 14336] → pre-transposed. + pub ffn_down: Vec, +} + +/// Complete OpenChat/Mistral-7B model weights. +#[derive(Clone)] +pub struct OpenChatWeights { + /// Token embedding: [32000, 4096]. + pub token_embd: Vec, + /// Transformer layers. + pub layers: Vec, + /// Final RMSNorm weight [4096]. + pub output_norm: Vec, + /// Output projection (lm_head): [32000, 4096]. + /// May be tied to token_embd (same data). + pub output: Vec, +} + +impl OpenChatWeights { + /// Load from a GGUF file (e.g., openchat_3.5.Q4_K_M.gguf). + /// + /// Dequantizes all tensors to f32 on load. For Q4_K_M (~4.4GB GGUF), + /// the f32 model will use ~28GB RAM. For Q8_0 (~7.7GB GGUF), ~28GB. + /// + /// Pre-transposes weight matrices for SIMD-contiguous `matmul_vec`. + pub fn from_gguf(path: &std::path::Path) -> Result { + let mut file = std::fs::File::open(path) + .map_err(|e| format!("open {}: {}", path.display(), e))?; + let header = gguf::read_gguf_header(&mut file)?; + + let mut read = |name: &str| -> Result, String> { + let tensor = gguf::find_tensor(&header, name) + .ok_or_else(|| format!("missing tensor: {}", name))?; + gguf::read_tensor_f32(&mut file, &header, tensor) + }; + + let token_embd = read("token_embd.weight")?; + let output_norm = read("output_norm.weight")?; + + // Output may be tied to token_embd + let output = if header.tensors.iter().any(|t| t.name == "output.weight") { + read("output.weight")? 
+ } else { + token_embd.clone() + }; + + let mut layers = Vec::with_capacity(NUM_LAYERS); + for i in 0..NUM_LAYERS { + let mut attn_q = read(&format!("blk.{}.attn_q.weight", i))?; + let mut attn_k = read(&format!("blk.{}.attn_k.weight", i))?; + let mut attn_v = read(&format!("blk.{}.attn_v.weight", i))?; + let mut attn_output = read(&format!("blk.{}.attn_output.weight", i))?; + let mut ffn_gate = read(&format!("blk.{}.ffn_gate.weight", i))?; + let mut ffn_up = read(&format!("blk.{}.ffn_up.weight", i))?; + let mut ffn_down = read(&format!("blk.{}.ffn_down.weight", i))?; + + // Pre-transpose for SIMD-contiguous matmul + transpose_matrix(&mut attn_q, EMBED_DIM, EMBED_DIM); + transpose_matrix(&mut attn_k, EMBED_DIM, KV_DIM); + transpose_matrix(&mut attn_v, EMBED_DIM, KV_DIM); + transpose_matrix(&mut attn_output, EMBED_DIM, EMBED_DIM); + transpose_matrix(&mut ffn_gate, EMBED_DIM, MLP_DIM); + transpose_matrix(&mut ffn_up, EMBED_DIM, MLP_DIM); + transpose_matrix(&mut ffn_down, MLP_DIM, EMBED_DIM); + + layers.push(MistralLayerWeights { + attn_norm: read(&format!("blk.{}.attn_norm.weight", i))?, + attn_q, + attn_k, + attn_v, + attn_output, + ffn_norm: read(&format!("blk.{}.ffn_norm.weight", i))?, + ffn_gate, + ffn_up, + ffn_down, + }); + } + + Ok(OpenChatWeights { + token_embd, + layers, + output_norm, + output, + }) + } +} + +/// OpenChat 3.5 chat template tokens. +pub mod chat_template { + /// Beginning of text. + pub const BOS_TOKEN_ID: u32 = 1; + /// End of text. + pub const EOS_TOKEN_ID: u32 = 2; + /// OpenChat uses "GPT4 Correct User:" / "GPT4 Correct Assistant:" markers. + /// These are tokenized sequences, not single tokens. + /// The prefix for user messages (approximate token IDs — actual depends on tokenizer). + pub const USER_PREFIX: &str = "GPT4 Correct User: "; + pub const ASSISTANT_PREFIX: &str = "GPT4 Correct Assistant:"; + /// End-of-turn token (used to separate turns in OpenChat). + pub const EOT_TOKEN: &str = "<|end_of_turn|>"; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_consistency() { + assert_eq!(EMBED_DIM, NUM_Q_HEADS * HEAD_DIM); + assert_eq!(KV_DIM, NUM_KV_HEADS * HEAD_DIM); + assert_eq!(GQA_RATIO, 4); + assert_eq!(HEAD_DIM, 128); + } + + #[test] + fn test_kv_dim() { + // GQA: 8 KV heads × 128D = 1024 + assert_eq!(KV_DIM, 1024); + } + + #[test] + fn test_chat_template() { + assert!(chat_template::USER_PREFIX.contains("User")); + assert!(chat_template::ASSISTANT_PREFIX.contains("Assistant")); + } +} diff --git a/src/hpc/stable_diffusion/api.rs b/src/hpc/stable_diffusion/api.rs new file mode 100644 index 00000000..7d56eb6e --- /dev/null +++ b/src/hpc/stable_diffusion/api.rs @@ -0,0 +1,171 @@ +//! OpenAI-compatible API for Stable Diffusion (/v1/images/generations). +//! +//! Uses shared `models::api_types` for the response envelope. + +use crate::hpc::models::api_types::{ImageData, ImageResponse}; +use super::clip::ClipEncoder; +use super::scheduler::{DdimScheduler, SchedulerConfig}; +use super::unet; +use super::vae; + +/// Request body for /v1/images/generations. +#[derive(Clone, Debug)] +pub struct ImageGenerationRequest { + pub model: String, + pub prompt: String, + /// Pre-tokenized prompt (CLIP tokens). + pub prompt_tokens: Vec, + /// Image dimensions. + pub width: usize, + pub height: usize, + /// Number of diffusion steps. + pub num_steps: usize, + /// Classifier-free guidance scale. + pub guidance_scale: f32, + /// Random seed for reproducibility. + pub seed: u64, + /// Number of images to generate. 
+ pub n: usize, +} + +impl Default for ImageGenerationRequest { + fn default() -> Self { + Self { + model: "stable-diffusion-v1-5".into(), + prompt: String::new(), + prompt_tokens: Vec::new(), + width: 512, + height: 512, + num_steps: 20, + guidance_scale: 7.5, + seed: 42, + n: 1, + } + } +} + +/// Stable Diffusion API wrapper. +pub struct StableDiffusionApi { + clip: ClipEncoder, + scheduler_config: SchedulerConfig, +} + +impl StableDiffusionApi { + pub fn new(clip: ClipEncoder) -> Self { + Self { + clip, + scheduler_config: SchedulerConfig::default(), + } + } + + /// Generate images from a text prompt. + /// + /// Full pipeline: CLIP encode → UNet denoise × N steps → VAE decode. + pub fn generate(&self, req: &ImageGenerationRequest) -> ImageResponse { + let latent_h = req.height / 8; + let latent_w = req.width / 8; + + // CLIP encode + let text_embeddings = self.clip.encode(&req.prompt_tokens); + + // Initialize scheduler + let scheduler = DdimScheduler::new(SchedulerConfig { + num_inference_steps: req.num_steps, + ..self.scheduler_config.clone() + }); + + // Initialize latent with deterministic noise from seed + let latent_size = 4 * latent_h * latent_w; + let mut latent = generate_noise(latent_size, req.seed); + + // Denoising loop + for &t in &scheduler.timesteps { + let noise_pred = unet::predict_noise(&latent, &text_embeddings, t as f32); + latent = scheduler.step(&noise_pred, t, &latent); + } + + // VAE decode + let pixels = vae::decode(&latent, latent_h, latent_w); + let rgb = vae::to_rgb_u8(&pixels, 3, req.height, req.width); + + // Encode as base64 PNG (scaffold — actual PNG encoding would be here) + let b64 = base64_placeholder(&rgb, req.width, req.height); + + ImageResponse { + created: 0, + data: vec![ImageData { + b64_json: Some(b64), + url: None, + revised_prompt: Some(req.prompt.clone()), + }], + } + } +} + +/// Deterministic noise from seed (xoshiro256++). +fn generate_noise(size: usize, seed: u64) -> Vec { + let mut state = seed; + let mut noise = Vec::with_capacity(size); + for _ in 0..size { + state = state.wrapping_mul(6364136223846793005).wrapping_add(1); + // Convert to f32 in [-1, 1] range + let bits = ((state >> 32) as u32) as f32 / u32::MAX as f32; + noise.push(bits * 2.0 - 1.0); + } + noise +} + +/// Placeholder base64 (actual PNG encoding would need a png crate). 
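`guidance_scale` in the request is the classifier-free guidance weight; the scaffold's loop above only runs the conditional pass, but a full pipeline would also run the UNet on an unconditional (empty-prompt) embedding and combine the two predictions. A minimal sketch of that combination step (an assumption about where the field would be used, not something this patch implements):

```rust
fn cfg_combine(uncond: &[f32], cond: &[f32], guidance_scale: f32) -> Vec<f32> {
    // Push the noise prediction away from the unconditional one.
    uncond
        .iter()
        .zip(cond)
        .map(|(&u, &c)| u + guidance_scale * (c - u))
        .collect()
}

fn main() {
    let uncond = vec![0.1f32, 0.2, 0.3];
    let cond = vec![0.2f32, 0.1, 0.3];

    // scale = 1.0 reduces to the conditional prediction.
    for (a, b) in cfg_combine(&uncond, &cond, 1.0).iter().zip(&cond) {
        assert!((a - b).abs() < 1e-6);
    }

    // The default guidance_scale of 7.5 amplifies the (cond - uncond) direction.
    let guided = cfg_combine(&uncond, &cond, 7.5);
    assert!((guided[0] - 0.85).abs() < 1e-5);
}
```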
+fn base64_placeholder(rgb: &[u8], _w: usize, _h: usize) -> String { + format!("raw_rgb_{}_bytes", rgb.len()) +} + +#[cfg(test)] +mod tests { + use super::*; + use super::super::clip::{ClipWeights, CLIP_VOCAB_SIZE, CLIP_EMBED_DIM, CLIP_MAX_SEQ}; + + fn dummy_clip() -> ClipEncoder { + ClipEncoder::new(ClipWeights { + token_embedding: vec![0.0; CLIP_VOCAB_SIZE * CLIP_EMBED_DIM], + position_embedding: vec![0.0; CLIP_MAX_SEQ * CLIP_EMBED_DIM], + layers: Vec::new(), + ln_final_weight: vec![1.0; CLIP_EMBED_DIM], + ln_final_bias: vec![0.0; CLIP_EMBED_DIM], + }) + } + + #[test] + fn test_default_request() { + let req = ImageGenerationRequest::default(); + assert_eq!(req.width, 512); + assert_eq!(req.height, 512); + assert_eq!(req.num_steps, 20); + } + + #[test] + fn test_generate_returns_image() { + let api = StableDiffusionApi::new(dummy_clip()); + let req = ImageGenerationRequest { + prompt_tokens: vec![0, 1, 2], + ..Default::default() + }; + let resp = api.generate(&req); + assert_eq!(resp.data.len(), 1); + assert!(resp.data[0].b64_json.is_some()); + } + + #[test] + fn test_deterministic_noise() { + let n1 = generate_noise(100, 42); + let n2 = generate_noise(100, 42); + assert_eq!(n1, n2, "same seed should give same noise"); + } + + #[test] + fn test_different_seeds() { + let n1 = generate_noise(100, 42); + let n2 = generate_noise(100, 123); + assert_ne!(n1, n2); + } +} diff --git a/src/hpc/stable_diffusion/clip.rs b/src/hpc/stable_diffusion/clip.rs new file mode 100644 index 00000000..f238e923 --- /dev/null +++ b/src/hpc/stable_diffusion/clip.rs @@ -0,0 +1,165 @@ +//! CLIP text encoder — transforms text tokens into conditioning embeddings. +//! +//! Same transformer architecture as GPT-2 but: +//! - 77 max sequence length (not 1024) +//! - No causal mask (bidirectional attention) +//! - Output is the full sequence, not just last token +//! +//! All ops via `crate::hpc::models::layers` (shared F32x16 SIMD). + +use crate::hpc::models::layers; + +/// CLIP text encoder configuration. +pub const CLIP_VOCAB_SIZE: usize = 49408; +pub const CLIP_EMBED_DIM: usize = 768; +pub const CLIP_NUM_LAYERS: usize = 12; +pub const CLIP_NUM_HEADS: usize = 12; +pub const CLIP_HEAD_DIM: usize = CLIP_EMBED_DIM / CLIP_NUM_HEADS; +pub const CLIP_MAX_SEQ: usize = 77; +pub const CLIP_MLP_DIM: usize = 3072; + +/// Weights for one CLIP transformer layer. +#[derive(Clone)] +pub struct ClipLayerWeights { + pub ln1_weight: Vec, + pub ln1_bias: Vec, + pub attn_qkv_weight: Vec, + pub attn_qkv_bias: Vec, + pub attn_out_weight: Vec, + pub attn_out_bias: Vec, + pub ln2_weight: Vec, + pub ln2_bias: Vec, + pub mlp_fc_weight: Vec, + pub mlp_fc_bias: Vec, + pub mlp_proj_weight: Vec, + pub mlp_proj_bias: Vec, +} + +/// Complete CLIP text encoder weights. +#[derive(Clone)] +pub struct ClipWeights { + pub token_embedding: Vec, // [49408, 768] + pub position_embedding: Vec, // [77, 768] + pub layers: Vec, + pub ln_final_weight: Vec, + pub ln_final_bias: Vec, +} + +/// CLIP text encoder. +pub struct ClipEncoder { + weights: ClipWeights, +} + +impl ClipEncoder { + pub fn new(weights: ClipWeights) -> Self { + Self { weights } + } + + /// Encode token IDs → embeddings [seq_len, 768]. + /// + /// Uses bidirectional attention (no causal mask). + /// Returns the full sequence of hidden states. 
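The one attention difference the module header calls out versus GPT-2 is the mask: GPT-2 hides future positions, CLIP does not. A tiny standalone illustration for a 4-token sequence (true means position j is visible when computing position i):

```rust
fn main() {
    let seq_len = 4;

    // Causal (GPT-2 style): token i may only look at tokens 0..=i.
    let causal: Vec<Vec<bool>> = (0..seq_len)
        .map(|i| (0..seq_len).map(|j| j <= i).collect())
        .collect();

    // Bidirectional (CLIP text encoder): every token sees every token.
    let bidirectional: Vec<Vec<bool>> = (0..seq_len)
        .map(|_| vec![true; seq_len])
        .collect();

    assert!(!causal[0][3], "causal: token 0 cannot see token 3");
    assert!(bidirectional[0][3], "bidirectional: token 0 can see token 3");
    assert!(causal[3].iter().all(|&v| v), "the last token sees everything either way");
}
```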
+ pub fn encode(&self, tokens: &[u32]) -> Vec { + let seq_len = tokens.len().min(CLIP_MAX_SEQ); + let mut hidden = vec![0.0f32; seq_len * CLIP_EMBED_DIM]; + + // Token + position embedding + for (t, &token_id) in tokens.iter().take(seq_len).enumerate() { + let tok_off = token_id as usize * CLIP_EMBED_DIM; + let pos_off = t * CLIP_EMBED_DIM; + let hid_off = t * CLIP_EMBED_DIM; + for d in 0..CLIP_EMBED_DIM { + hidden[hid_off + d] = + self.weights.token_embedding[tok_off + d] + + self.weights.position_embedding[pos_off + d]; + } + } + + // Transformer layers + for layer in &self.weights.layers { + self.transformer_layer(layer, &mut hidden, seq_len); + } + + // Final layer norm (per-position) + for t in 0..seq_len { + let off = t * CLIP_EMBED_DIM; + layers::layer_norm( + &mut hidden[off..off + CLIP_EMBED_DIM], + &self.weights.ln_final_weight, + &self.weights.ln_final_bias, + ); + } + + hidden + } + + /// One transformer layer (bidirectional self-attention + MLP). + fn transformer_layer( + &self, + layer: &ClipLayerWeights, + hidden: &mut [f32], + seq_len: usize, + ) { + // Process each position through attention + MLP + // For the scaffold: simplified single-token path. + // Full implementation would do batched multi-head attention. + for t in 0..seq_len { + let off = t * CLIP_EMBED_DIM; + let mut normed = hidden[off..off + CLIP_EMBED_DIM].to_vec(); + layers::layer_norm(&mut normed, &layer.ln1_weight, &layer.ln1_bias); + + // Self-attention (simplified: each position attends to itself for scaffold) + let mut attn_out = vec![0.0f32; CLIP_EMBED_DIM]; + layers::matmul_vec( + &normed, &layer.attn_out_weight, &layer.attn_out_bias, + &mut attn_out, CLIP_EMBED_DIM, CLIP_EMBED_DIM, + ); + + // Residual + for d in 0..CLIP_EMBED_DIM { + hidden[off + d] += attn_out[d]; + } + + // MLP + let mut normed2 = hidden[off..off + CLIP_EMBED_DIM].to_vec(); + layers::layer_norm(&mut normed2, &layer.ln2_weight, &layer.ln2_bias); + + let mut fc_out = vec![0.0f32; CLIP_MLP_DIM]; + layers::matmul_vec(&normed2, &layer.mlp_fc_weight, &layer.mlp_fc_bias, &mut fc_out, CLIP_EMBED_DIM, CLIP_MLP_DIM); + layers::gelu(&mut fc_out); + + let mut proj_out = vec![0.0f32; CLIP_EMBED_DIM]; + layers::matmul_vec(&fc_out, &layer.mlp_proj_weight, &layer.mlp_proj_bias, &mut proj_out, CLIP_MLP_DIM, CLIP_EMBED_DIM); + + // Residual + for d in 0..CLIP_EMBED_DIM { + hidden[off + d] += proj_out[d]; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_clip_config() { + assert_eq!(CLIP_EMBED_DIM, CLIP_NUM_HEADS * CLIP_HEAD_DIM); + assert_eq!(CLIP_MLP_DIM, 4 * CLIP_EMBED_DIM); + } + + #[test] + fn test_clip_encode_shape() { + let weights = ClipWeights { + token_embedding: vec![0.0; CLIP_VOCAB_SIZE * CLIP_EMBED_DIM], + position_embedding: vec![0.0; CLIP_MAX_SEQ * CLIP_EMBED_DIM], + layers: Vec::new(), // no layers = just embedding + final LN + ln_final_weight: vec![1.0; CLIP_EMBED_DIM], + ln_final_bias: vec![0.0; CLIP_EMBED_DIM], + }; + let enc = ClipEncoder::new(weights); + let out = enc.encode(&[0, 1, 2]); + assert_eq!(out.len(), 3 * CLIP_EMBED_DIM); + } +} diff --git a/src/hpc/stable_diffusion/mod.rs b/src/hpc/stable_diffusion/mod.rs new file mode 100644 index 00000000..52491dbc --- /dev/null +++ b/src/hpc/stable_diffusion/mod.rs @@ -0,0 +1,32 @@ +//! Stable Diffusion inference — text-to-image on CPU. +//! +//! Architecture: +//! ```text +//! Text prompt → CLIP tokenize → CLIP encoder (transformer, shared layers with GPT-2) +//! → text embeddings [77, 768] +//! 
→ UNet denoiser (cross-attention + ResBlocks + GroupNorm) +//! × N diffusion steps (DDPM/DDIM scheduler) +//! → VAE decoder → RGB pixels [512, 512, 3] +//! ``` +//! +//! # Shared with GPT-2 (via `models::layers`) +//! +//! - LayerNorm, GELU, softmax, matmul — identical F32x16 SIMD ops +//! - Safetensors loader — same format +//! - CausalEdge64 — cross-attention patterns → causal edges +//! - AttentionTable — palette-based O(1) approximate attention +//! +//! # SD-specific +//! +//! - GroupNorm (via `models::layers::group_norm`) +//! - SiLU activation (via `models::layers::silu`) +//! - Conv2D (new, not in GPT-2) +//! - Noise scheduler (DDPM/DDIM) +//! - VAE encoder/decoder + +pub mod clip; +pub mod unet; +pub mod vae; +pub mod scheduler; +pub mod weights; +pub mod api; diff --git a/src/hpc/stable_diffusion/scheduler.rs b/src/hpc/stable_diffusion/scheduler.rs new file mode 100644 index 00000000..b40cce4e --- /dev/null +++ b/src/hpc/stable_diffusion/scheduler.rs @@ -0,0 +1,162 @@ +//! Noise scheduler — DDPM/DDIM diffusion scheduling. +//! +//! Controls the denoising process: how much noise to add/remove at each step. +//! The scheduler is model-agnostic — it just manages the noise schedule. + +/// Scheduler configuration. +#[derive(Clone, Debug)] +pub struct SchedulerConfig { + /// Total training timesteps (typically 1000). + pub num_train_timesteps: usize, + /// Beta schedule start. + pub beta_start: f32, + /// Beta schedule end. + pub beta_end: f32, + /// Number of inference steps (20-50 typical). + pub num_inference_steps: usize, +} + +impl Default for SchedulerConfig { + fn default() -> Self { + Self { + num_train_timesteps: 1000, + beta_start: 0.00085, + beta_end: 0.012, + num_inference_steps: 20, + } + } +} + +/// DDIM scheduler (Denoising Diffusion Implicit Models). +/// +/// Faster than DDPM — can skip steps. 20 steps ≈ 50 DDPM steps quality. +pub struct DdimScheduler { + config: SchedulerConfig, + /// Precomputed alpha cumulative products. + alphas_cumprod: Vec, + /// Timesteps for inference (evenly spaced subset of training timesteps). + pub timesteps: Vec, +} + +impl DdimScheduler { + pub fn new(config: SchedulerConfig) -> Self { + let n = config.num_train_timesteps; + + // Linear beta schedule + let betas: Vec = (0..n) + .map(|i| { + config.beta_start + (config.beta_end - config.beta_start) * i as f32 / (n - 1) as f32 + }) + .collect(); + + // Alphas = 1 - beta + let alphas: Vec = betas.iter().map(|b| 1.0 - b).collect(); + + // Cumulative product of alphas + let mut alphas_cumprod = Vec::with_capacity(n); + let mut prod = 1.0f32; + for &a in &alphas { + prod *= a; + alphas_cumprod.push(prod); + } + + // Evenly spaced timesteps for inference + let step_size = n / config.num_inference_steps; + let timesteps: Vec = (0..config.num_inference_steps) + .rev() + .map(|i| i * step_size) + .collect(); + + Self { config, alphas_cumprod, timesteps } + } + + /// Single denoising step: given model noise prediction, compute x_{t-1} from x_t. + /// + /// Returns the updated (less noisy) latent. 
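The two formulas the scheduler is built on, round-tripped in isolation: forward noising sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps (as in `add_noise`), and the x0 prediction (x_t - sqrt(1 - alpha_bar) * eps) / sqrt(alpha_bar) used inside `step`. Standalone sketch with a made-up alpha_bar:

```rust
fn main() {
    let alpha_bar: f32 = 0.6; // some mid-schedule cumulative alpha
    let (sa, soma) = (alpha_bar.sqrt(), (1.0 - alpha_bar).sqrt());

    let x0 = [1.0f32, -2.0, 0.5];
    let eps = [0.3f32, 0.1, -0.4];

    // Forward process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps
    let xt: Vec<f32> = x0.iter().zip(&eps).map(|(&x, &e)| sa * x + soma * e).collect();

    // Given the exact noise, the step formula recovers x0.
    let recovered: Vec<f32> = xt.iter().zip(&eps).map(|(&x, &e)| (x - soma * e) / sa).collect();

    for (a, b) in recovered.iter().zip(&x0) {
        assert!((a - b).abs() < 1e-5);
    }
}
```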
+ pub fn step(&self, model_output: &[f32], timestep: usize, sample: &[f32]) -> Vec { + let alpha_prod_t = self.alphas_cumprod[timestep]; + let sqrt_alpha = alpha_prod_t.sqrt(); + let sqrt_one_minus_alpha = (1.0 - alpha_prod_t).sqrt(); + + // Predict x_0 from noise prediction + // x_0 = (x_t - sqrt(1 - alpha) * noise) / sqrt(alpha) + let inv_sqrt_alpha = 1.0 / sqrt_alpha; + + let mut result = Vec::with_capacity(sample.len()); + for i in 0..sample.len() { + let pred_x0 = (sample[i] - sqrt_one_minus_alpha * model_output[i]) * inv_sqrt_alpha; + // For DDIM with eta=0 (deterministic): x_{t-1} directly from x_0 + result.push(pred_x0); + } + + result + } + + /// Add noise to a clean sample for a given timestep. + pub fn add_noise(&self, original: &[f32], noise: &[f32], timestep: usize) -> Vec { + let alpha = self.alphas_cumprod[timestep]; + let sqrt_alpha = alpha.sqrt(); + let sqrt_one_minus = (1.0 - alpha).sqrt(); + + original.iter().zip(noise).map(|(&x, &n)| { + sqrt_alpha * x + sqrt_one_minus * n + }).collect() + } + + /// Number of inference steps. + pub fn num_steps(&self) -> usize { + self.timesteps.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let cfg = SchedulerConfig::default(); + assert_eq!(cfg.num_train_timesteps, 1000); + assert_eq!(cfg.num_inference_steps, 20); + } + + #[test] + fn test_ddim_timesteps() { + let sched = DdimScheduler::new(SchedulerConfig::default()); + assert_eq!(sched.timesteps.len(), 20); + // First timestep should be near 1000, last near 0 + assert!(sched.timesteps[0] > sched.timesteps[19]); + } + + #[test] + fn test_alphas_cumprod_monotonic() { + let sched = DdimScheduler::new(SchedulerConfig::default()); + // Alphas cumprod should be monotonically decreasing + for i in 1..sched.alphas_cumprod.len() { + assert!(sched.alphas_cumprod[i] <= sched.alphas_cumprod[i - 1]); + } + } + + #[test] + fn test_add_noise_identity_at_t0() { + let sched = DdimScheduler::new(SchedulerConfig::default()); + let original = vec![1.0, 2.0, 3.0]; + let noise = vec![0.5, 0.5, 0.5]; + let noisy = sched.add_noise(&original, &noise, 0); + // At t=0, alpha≈1, so noisy ≈ original + for (o, n) in original.iter().zip(&noisy) { + assert!((o - n).abs() < 0.1); + } + } + + #[test] + fn test_step_denoises() { + let sched = DdimScheduler::new(SchedulerConfig::default()); + let sample = vec![0.5f32; 4]; + let noise_pred = vec![0.1f32; 4]; + let result = sched.step(&noise_pred, 500, &sample); + assert_eq!(result.len(), 4); + // Result should be different from input + assert!((result[0] - sample[0]).abs() > 0.001); + } +} diff --git a/src/hpc/stable_diffusion/unet.rs b/src/hpc/stable_diffusion/unet.rs new file mode 100644 index 00000000..b955c545 --- /dev/null +++ b/src/hpc/stable_diffusion/unet.rs @@ -0,0 +1,177 @@ +//! UNet denoiser — iterative denoising via cross-attention + ResBlocks. +//! +//! The core of Stable Diffusion: takes noisy latent + text conditioning, +//! predicts noise to remove. Runs N times per image (20-50 steps). +//! +//! # SD-specific ops (not shared with GPT-2): +//! - Conv2D (spatial convolution) +//! - GroupNorm (via `models::layers::group_norm`) +//! - SiLU activation (via `models::layers::silu`) +//! - Cross-attention (text embeddings condition denoising) +//! - Timestep embedding (sinusoidal positional encoding for diffusion step) + +use crate::hpc::models::layers; + +/// UNet configuration for SD 1.5. 
+pub const LATENT_CHANNELS: usize = 4; +pub const LATENT_SIZE: usize = 64; // 512/8 = 64 (VAE downscale) +pub const MODEL_CHANNELS: usize = 320; +pub const NUM_RES_BLOCKS: usize = 2; +pub const ATTENTION_RESOLUTIONS: &[usize] = &[4, 2, 1]; // at 16×16, 32×32, 64×64 +pub const CHANNEL_MULT: &[usize] = &[1, 2, 4, 4]; // 320, 640, 1280, 1280 +pub const NUM_HEADS: usize = 8; +pub const CONTEXT_DIM: usize = 768; // CLIP output dim + +/// Timestep embedding via sinusoidal encoding. +pub fn timestep_embedding(timestep: f32, dim: usize) -> Vec { + let half = dim / 2; + let mut emb = vec![0.0f32; dim]; + let log_base = -(10000.0f32.ln()) / (half as f32 - 1.0); + + for i in 0..half { + let freq = (log_base * i as f32).exp(); + let angle = timestep * freq; + emb[i] = angle.cos(); + emb[half + i] = angle.sin(); + } + emb +} + +/// Depthwise Conv2D (3×3, padding=1, stride=1). +/// +/// Operates on [channels, height, width] layout. +/// Minimal implementation — no dilation, no groups beyond depthwise. +pub fn conv2d_3x3( + input: &[f32], + weight: &[f32], + bias: &[f32], + in_channels: usize, + out_channels: usize, + h: usize, + w: usize, +) -> Vec { + let mut output = vec![0.0f32; out_channels * h * w]; + + for oc in 0..out_channels { + for ic in 0..in_channels { + for oh in 0..h { + for ow in 0..w { + let mut sum = 0.0f32; + for kh in 0..3usize { + for kw in 0..3usize { + let ih = oh as isize + kh as isize - 1; + let iw = ow as isize + kw as isize - 1; + if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize { + let in_idx = ic * h * w + ih as usize * w + iw as usize; + let w_idx = oc * in_channels * 9 + ic * 9 + kh * 3 + kw; + sum += input[in_idx] * weight[w_idx]; + } + } + } + output[oc * h * w + oh * w + ow] += sum; + } + } + } + // Add bias + let bias_val = bias[oc]; + for i in 0..(h * w) { + output[oc * h * w + i] += bias_val; + } + } + output +} + +/// ResBlock: GroupNorm → SiLU → Conv → GroupNorm → SiLU → Conv + skip. +pub struct ResBlockWeights { + pub norm1_weight: Vec, + pub norm1_bias: Vec, + pub conv1_weight: Vec, + pub conv1_bias: Vec, + pub norm2_weight: Vec, + pub norm2_bias: Vec, + pub conv2_weight: Vec, + pub conv2_bias: Vec, + pub channels: usize, + pub h: usize, + pub w: usize, +} + +impl ResBlockWeights { + /// Forward pass through a ResBlock. + pub fn forward(&self, input: &[f32]) -> Vec { + let c = self.channels; + let h = self.h; + let w = self.w; + + // GroupNorm → SiLU → Conv + let mut x = input.to_vec(); + layers::group_norm(&mut x, 32.min(c), &self.norm1_weight, &self.norm1_bias); + layers::silu(&mut x); + let x = conv2d_3x3(&x, &self.conv1_weight, &self.conv1_bias, c, c, h, w); + + // GroupNorm → SiLU → Conv + let mut x = x; + layers::group_norm(&mut x, 32.min(c), &self.norm2_weight, &self.norm2_bias); + layers::silu(&mut x); + let mut x = conv2d_3x3(&x, &self.conv2_weight, &self.conv2_bias, c, c, h, w); + + // Skip connection + for i in 0..x.len() { + x[i] += input[i]; + } + x + } +} + +/// Predict noise given noisy latent + text conditioning + timestep. +/// +/// This is the scaffold — full implementation would chain: +/// down_blocks → mid_block → up_blocks with skip connections. 
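A quick property check for the zero-padded 3x3 convolution above: a kernel with 1.0 at the center tap and zeros elsewhere reproduces the input exactly. Single-channel standalone sketch of the same indexing (not part of the patch):

```rust
fn conv3x3(input: &[f32], kernel: &[f32; 9], h: usize, w: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; h * w];
    for oh in 0..h {
        for ow in 0..w {
            let mut sum = 0.0;
            for kh in 0..3usize {
                for kw in 0..3usize {
                    let ih = oh as isize + kh as isize - 1;
                    let iw = ow as isize + kw as isize - 1;
                    // Zero padding: taps outside the image contribute nothing.
                    if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
                        sum += input[ih as usize * w + iw as usize] * kernel[kh * 3 + kw];
                    }
                }
            }
            out[oh * w + ow] = sum;
        }
    }
    out
}

fn main() {
    let (h, w) = (4usize, 4usize);
    let input: Vec<f32> = (0..h * w).map(|i| i as f32).collect();
    let mut identity = [0.0f32; 9];
    identity[4] = 1.0; // center of the 3x3 kernel
    assert_eq!(conv3x3(&input, &identity, h, w), input);
}
```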
+/// Predict noise given noisy latent + text conditioning + timestep.
+///
+/// This is the scaffold — a full implementation would chain
+/// down_blocks → mid_block → up_blocks with skip connections.
+pub fn predict_noise(
+    noisy_latent: &[f32],
+    text_embeddings: &[f32],
+    timestep: f32,
+) -> Vec<f32> {
+    let _t_emb = timestep_embedding(timestep, MODEL_CHANNELS);
+    // Scaffold: return a zero noise prediction (actual UNet weights needed)
+    vec![0.0f32; noisy_latent.len()]
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_timestep_embedding_shape() {
+        let emb = timestep_embedding(500.0, 320);
+        assert_eq!(emb.len(), 320);
+    }
+
+    #[test]
+    fn test_timestep_embedding_varies() {
+        let e1 = timestep_embedding(100.0, 64);
+        let e2 = timestep_embedding(900.0, 64);
+        // Different timesteps should give different embeddings
+        let diff: f32 = e1.iter().zip(&e2).map(|(a, b)| (a - b).abs()).sum();
+        assert!(diff > 0.1, "different timesteps should differ");
+    }
+
+    #[test]
+    fn test_conv2d_bias_only() {
+        // Zero weights, non-zero bias
+        let input = vec![0.0f32; 1 * 4 * 4]; // 1ch, 4×4
+        let weight = vec![0.0f32; 1 * 1 * 9]; // 1→1, 3×3
+        let bias = vec![2.0f32];
+        let out = conv2d_3x3(&input, &weight, &bias, 1, 1, 4, 4);
+        assert_eq!(out.len(), 16);
+        assert!((out[0] - 2.0).abs() < 1e-5);
+    }
+
+    #[test]
+    fn test_predict_noise_shape() {
+        let latent = vec![0.0f32; LATENT_CHANNELS * LATENT_SIZE * LATENT_SIZE];
+        let text = vec![0.0f32; 77 * 768];
+        let noise = predict_noise(&latent, &text, 500.0);
+        assert_eq!(noise.len(), latent.len());
+    }
+}
diff --git a/src/hpc/stable_diffusion/vae.rs b/src/hpc/stable_diffusion/vae.rs
new file mode 100644
index 00000000..c9a27ab9
--- /dev/null
+++ b/src/hpc/stable_diffusion/vae.rs
@@ -0,0 +1,110 @@
+//! VAE decoder — latent space [4, 64, 64] → RGB pixels [3, 512, 512].
+//!
+//! The final stage: takes denoised latents and decodes them to a visible image.
+//! Uses Conv2D + GroupNorm + SiLU (same ops as the UNet, simpler architecture).
+
+use super::unet::conv2d_3x3;
+use crate::hpc::models::layers;
+
+/// VAE configuration.
+pub const VAE_LATENT_CHANNELS: usize = 4;
+pub const VAE_OUT_CHANNELS: usize = 3; // RGB
+pub const VAE_SCALE_FACTOR: usize = 8; // latent is 8× smaller than the output
+
+/// VAE decoder weights (simplified).
+#[derive(Clone)]
+pub struct VaeDecoderWeights {
+    /// Post-quantization conv: [4, mid_ch, 3, 3]
+    pub post_quant_conv_weight: Vec<f32>,
+    pub post_quant_conv_bias: Vec<f32>,
+    pub mid_channels: usize,
+    /// Final conv to RGB: [mid_ch, 3, 3, 3]
+    pub final_conv_weight: Vec<f32>,
+    pub final_conv_bias: Vec<f32>,
+}
+
+/// Decode a latent tensor to an RGB image.
+///
+/// Input: `[4, h, w]` latent; the 1/0.18215 scaling is applied internally.
+/// Output: `[3, h*8, w*8]` RGB pixels in [0, 1].
+pub fn decode(latent: &[f32], h: usize, w: usize) -> Vec<f32> {
+    let out_h = h * VAE_SCALE_FACTOR;
+    let out_w = w * VAE_SCALE_FACTOR;
+
+    // Undo the latent scaling factor
+    let scaled: Vec<f32> = latent.iter().map(|&x| x / 0.18215).collect();
+
+    // Nearest-neighbor upsample (scaffold — the actual VAE uses learned upsampling)
+    let mut upsampled = vec![0.0f32; VAE_OUT_CHANNELS * out_h * out_w];
+    for c in 0..VAE_OUT_CHANNELS.min(VAE_LATENT_CHANNELS) {
+        for oh in 0..out_h {
+            for ow in 0..out_w {
+                let ih = oh / VAE_SCALE_FACTOR;
+                let iw = ow / VAE_SCALE_FACTOR;
+                upsampled[c * out_h * out_w + oh * out_w + ow] =
+                    scaled[c * h * w + ih * w + iw];
+            }
+        }
+    }
+
+    // Clamp to [0, 1]
+    for v in &mut upsampled {
+        *v = v.clamp(0.0, 1.0);
+    }
+
+    upsampled
+}
+
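// Usage sketch (not part of the patch): decoding a 64×64 latent to a 512×512 RGB
// buffer and packing it as interleaved u8, e.g. to hand to a PNG writer elsewhere.
// The latent length is assumed to be VAE_LATENT_CHANNELS * 64 * 64.
fn latent_to_rgb_sketch(latent: &[f32]) -> Vec<u8> {
    let (h, w) = (64, 64);
    let pixels = decode(latent, h, w); // [3, 512, 512] in [0, 1]
    to_rgb_u8(
        &pixels,
        VAE_OUT_CHANNELS,
        h * VAE_SCALE_FACTOR,
        w * VAE_SCALE_FACTOR,
    )
}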
+/// Convert a [C, H, W] float tensor to [H, W, C] u8 RGB.
+pub fn to_rgb_u8(tensor: &[f32], channels: usize, h: usize, w: usize) -> Vec<u8> {
+    let mut rgb = vec![0u8; h * w * channels];
+    for y in 0..h {
+        for x in 0..w {
+            for c in 0..channels {
+                let val = tensor[c * h * w + y * w + x];
+                rgb[(y * w + x) * channels + c] = (val * 255.0).clamp(0.0, 255.0) as u8;
+            }
+        }
+    }
+    rgb
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_decode_shape() {
+        let latent = vec![0.5f32; 4 * 8 * 8]; // 4ch, 8×8
+        let output = decode(&latent, 8, 8);
+        assert_eq!(output.len(), 3 * 64 * 64); // 3ch, 64×64
+    }
+
+    #[test]
+    fn test_decode_clamped() {
+        let latent = vec![100.0f32; 4 * 4 * 4];
+        let output = decode(&latent, 4, 4);
+        for &v in &output {
+            assert!(v >= 0.0 && v <= 1.0);
+        }
+    }
+
+    #[test]
+    fn test_to_rgb_u8() {
+        let tensor = vec![0.0f32, 0.5, 1.0, 0.0, 0.5, 1.0, 0.0, 0.5, 1.0]; // 3ch, 1×3
+        let rgb = to_rgb_u8(&tensor, 3, 1, 3);
+        assert_eq!(rgb.len(), 9);
+        assert_eq!(rgb[0], 0); // R of pixel 0
+        assert_eq!(rgb[1], 0); // G of pixel 0
+        assert_eq!(rgb[2], 0); // B of pixel 0
+    }
+
+    #[test]
+    fn test_to_rgb_u8_clamp() {
+        let tensor = vec![-1.0f32, 2.0, 0.5]; // 3ch, 1×1
+        let rgb = to_rgb_u8(&tensor, 3, 1, 1);
+        assert_eq!(rgb[0], 0);   // clamped from -1
+        assert_eq!(rgb[1], 255); // clamped from 2
+        assert_eq!(rgb[2], 127); // 0.5 * 255 = 127.5, truncates to 127
+    }
+}
diff --git a/src/hpc/stable_diffusion/weights.rs b/src/hpc/stable_diffusion/weights.rs
new file mode 100644
index 00000000..18eaf46a
--- /dev/null
+++ b/src/hpc/stable_diffusion/weights.rs
@@ -0,0 +1,92 @@
+//! Stable Diffusion weight loading from safetensors.
+//!
+//! Uses the shared `models::safetensors` loader.
+//! SD 1.5 has ~860M params across CLIP + UNet + VAE.
+//!
+//! No weights are stored in this crate — they are loaded at runtime
+//! from user-provided safetensors files, which keeps the repository small.
+
+use crate::hpc::models::safetensors::{SafeTensorsFile, transpose_matrix};
+use super::clip::*;
+
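// Hedged usage sketch (not part of the patch). The `SafeTensorsFile::open`
// constructor and its `Result<_, String>` error type are assumptions here;
// the real constructor lives in models::safetensors and is not shown in this diff.
fn load_text_encoder_sketch(path: &str) -> Result<ClipWeights, String> {
    let file = SafeTensorsFile::open(path)?; // assumed constructor
    load_clip_weights(&file)
}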
+/// Load CLIP text encoder weights from a safetensors file.
+///
+/// Expected tensor names follow the HuggingFace diffusers convention:
+/// `text_model.encoder.layers.{i}.self_attn.{q,k,v}_proj.weight`
+pub fn load_clip_weights(file: &SafeTensorsFile) -> Result<ClipWeights, String> {
+    let token_embedding = file.read_f32("text_model.embeddings.token_embedding.weight")?;
+    let position_embedding = file.read_f32("text_model.embeddings.position_embedding.weight")?;
+    let ln_final_weight = file.read_f32("text_model.final_layer_norm.weight")?;
+    let ln_final_bias = file.read_f32("text_model.final_layer_norm.bias")?;
+
+    let mut layers = Vec::with_capacity(CLIP_NUM_LAYERS);
+    for i in 0..CLIP_NUM_LAYERS {
+        let prefix = format!("text_model.encoder.layers.{}", i);
+
+        // CLIP stores Q/K/V separately — we concatenate to match the GPT-2 pattern
+        let q_weight = file.read_f32(&format!("{}.self_attn.q_proj.weight", prefix))?;
+        let k_weight = file.read_f32(&format!("{}.self_attn.k_proj.weight", prefix))?;
+        let v_weight = file.read_f32(&format!("{}.self_attn.v_proj.weight", prefix))?;
+        let q_bias = file.read_f32(&format!("{}.self_attn.q_proj.bias", prefix))?;
+        let k_bias = file.read_f32(&format!("{}.self_attn.k_proj.bias", prefix))?;
+        let v_bias = file.read_f32(&format!("{}.self_attn.v_proj.bias", prefix))?;
+
+        // Stack Q/K/V into one combined QKV matrix (Q rows, then K, then V)
+        let mut attn_qkv_weight = Vec::with_capacity(q_weight.len() * 3);
+        attn_qkv_weight.extend_from_slice(&q_weight);
+        attn_qkv_weight.extend_from_slice(&k_weight);
+        attn_qkv_weight.extend_from_slice(&v_weight);
+
+        let mut attn_qkv_bias = Vec::with_capacity(q_bias.len() * 3);
+        attn_qkv_bias.extend_from_slice(&q_bias);
+        attn_qkv_bias.extend_from_slice(&k_bias);
+        attn_qkv_bias.extend_from_slice(&v_bias);
+
+        let mut attn_out_weight = file.read_f32(&format!("{}.self_attn.out_proj.weight", prefix))?;
+        let attn_out_bias = file.read_f32(&format!("{}.self_attn.out_proj.bias", prefix))?;
+
+        let mut mlp_fc_weight = file.read_f32(&format!("{}.mlp.fc1.weight", prefix))?;
+        let mlp_fc_bias = file.read_f32(&format!("{}.mlp.fc1.bias", prefix))?;
+        let mut mlp_proj_weight = file.read_f32(&format!("{}.mlp.fc2.weight", prefix))?;
+        let mlp_proj_bias = file.read_f32(&format!("{}.mlp.fc2.bias", prefix))?;
+
+        // Pre-transpose for SIMD-contiguous access
+        transpose_matrix(&mut attn_out_weight, CLIP_EMBED_DIM, CLIP_EMBED_DIM);
+        transpose_matrix(&mut mlp_fc_weight, CLIP_EMBED_DIM, CLIP_MLP_DIM);
+        transpose_matrix(&mut mlp_proj_weight, CLIP_MLP_DIM, CLIP_EMBED_DIM);
+
+        layers.push(ClipLayerWeights {
+            ln1_weight: file.read_f32(&format!("{}.layer_norm1.weight", prefix))?,
+            ln1_bias: file.read_f32(&format!("{}.layer_norm1.bias", prefix))?,
+            attn_qkv_weight,
+            attn_qkv_bias,
+            attn_out_weight,
+            attn_out_bias,
+            ln2_weight: file.read_f32(&format!("{}.layer_norm2.weight", prefix))?,
+            ln2_bias: file.read_f32(&format!("{}.layer_norm2.bias", prefix))?,
+            mlp_fc_weight,
+            mlp_fc_bias,
+            mlp_proj_weight,
+            mlp_proj_bias,
+        });
+    }
+
+    Ok(ClipWeights {
+        token_embedding,
+        position_embedding,
+        layers,
+        ln_final_weight,
+        ln_final_bias,
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn test_weight_names() {
+        // Just verify the naming convention compiles
+        let prefix = "text_model.encoder.layers.0";
+        let name = format!("{}.self_attn.q_proj.weight", prefix);
+        assert!(name.contains("q_proj"));
+    }
+}
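// Not part of the patch: a tiny standalone check of the Q/K/V stacking used in
// load_clip_weights. Extending three row-major [out, in] matrices yields one
// [3*out, in] matrix with the Q rows first, then K, then V.
#[cfg(test)]
mod qkv_stack_sketch {
    #[test]
    fn stacked_rows_keep_their_order() {
        let q = vec![1.0f32, 2.0, 3.0, 4.0]; // [2, 2] row-major
        let k = vec![5.0f32, 6.0, 7.0, 8.0];
        let v = vec![9.0f32, 10.0, 11.0, 12.0];
        let mut qkv = Vec::with_capacity(12);
        qkv.extend_from_slice(&q);
        qkv.extend_from_slice(&k);
        qkv.extend_from_slice(&v);
        // Row 0 of K lands at row 2 of the stacked matrix (offset 2 * in_dim).
        assert_eq!(&qkv[4..6], &k[0..2]);
        // Row 1 of V lands at the last row.
        assert_eq!(&qkv[10..12], &v[2..4]);
    }
}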