diff --git a/src/hpc/gguf_indexer.rs b/src/hpc/gguf_indexer.rs
index fe385281..35c01fc3 100644
--- a/src/hpc/gguf_indexer.rs
+++ b/src/hpc/gguf_indexer.rs
@@ -519,19 +519,34 @@ pub fn stream_index_gguf_bf16<R: Read + Seek, W: Write>(
     octave_stride: usize,
     callback: Option<&dyn Fn(&str, &LayerType, usize, usize)>,
 ) -> Result<IndexStats, String> {
-    let gguf_header = gguf::read_gguf_header(reader)?;
+    let header = gguf::read_gguf_header(reader)?;
+    stream_index_gguf_bf16_with_header(reader, writer, &header, octave_stride, callback)
+}
+
+/// Core BF16-direct indexer — works with any pre-parsed header (GGUF or safetensors).
+///
+/// The header must have:
+/// - `tensor_data_offset`: absolute byte offset where tensor data starts
+/// - `tensors`: Vec with name, dimensions, dtype, offset (relative to data start)
+pub fn stream_index_gguf_bf16_with_header<R: Read + Seek, W: Write>(
+    reader: &mut R,
+    writer: &mut W,
+    header: &gguf::GgufFile,
+    octave_stride: usize,
+    callback: Option<&dyn Fn(&str, &LayerType, usize, usize)>,
+) -> Result<IndexStats, String> {
     let mut stats = IndexStats::default();
-    stats.tensors_total = gguf_header.tensors.len();
+    stats.tensors_total = header.tensors.len();
 
     writer.write_all(b"BGZ7").map_err(|e| e.to_string())?;
-    writer.write_all(&(gguf_header.tensors.len() as u32).to_le_bytes()).map_err(|e| e.to_string())?;
+    writer.write_all(&(header.tensors.len() as u32).to_le_bytes()).map_err(|e| e.to_string())?;
 
     // Reusable buffer — capped at 128 MB (64M u16 elements).
     // Tensors larger than this are read in row batches.
     const MAX_BUF_ELEMS: usize = 64 * 1024 * 1024; // 128 MB of u16
     let mut bf16_buf: Vec<u16> = Vec::new();
 
-    for tensor in &gguf_header.tensors {
+    for tensor in &header.tensors {
         let layer_type = classify_tensor(&tensor.name, &tensor.dimensions);
 
         if matches!(layer_type, LayerType::Skip | LayerType::Norm) {
@@ -559,7 +574,7 @@ pub fn stream_index_gguf_bf16(
         }
 
         // Seek to tensor start
-        let abs_offset = gguf_header.tensor_data_offset + tensor.offset;
+        let abs_offset = header.tensor_data_offset + tensor.offset;
         reader.seek(std::io::SeekFrom::Start(abs_offset)).map_err(|e| e.to_string())?;
 
         let mut rows: Vec<Vec<f32>> = Vec::with_capacity(n_rows);
@@ -636,7 +651,7 @@ pub fn stream_index_gguf_bf16(
             }
         } else {
             // FALLBACK: non-BF16 — use original f32 path
-            let data = gguf::read_tensor_f32(reader, &gguf_header, tensor)?;
+            let data = gguf::read_tensor_f32(reader, header, tensor)?;
             let tensor_bytes = data.len() as u64 * 4;
             if tensor_bytes > stats.peak_tensor_bytes {
                 stats.peak_tensor_bytes = tensor_bytes;
diff --git a/src/hpc/mod.rs b/src/hpc/mod.rs
index 988e3280..cf3112dc 100644
--- a/src/hpc/mod.rs
+++ b/src/hpc/mod.rs
@@ -172,6 +172,10 @@ pub mod gguf;
 #[allow(missing_docs)]
 pub mod gguf_indexer;
 
+/// Safetensors header parser + streaming indexer for BF16 model weights.
+#[allow(missing_docs)]
+pub mod safetensors;
+
 /// HTTP range reader — Read + Seek over HTTP for streaming GGUF from HuggingFace.
 #[allow(missing_docs)]
 pub mod http_reader;
diff --git a/src/hpc/safetensors.rs b/src/hpc/safetensors.rs
new file mode 100644
index 00000000..675a0c58
--- /dev/null
+++ b/src/hpc/safetensors.rs
@@ -0,0 +1,414 @@
+//! Safetensors header parser — streaming support for the bgz17 indexer.
+//!
+//! Parses the safetensors JSON header and produces `GgufFile` + `TensorInfo`
+//! so that `stream_index_gguf_bf16` works unchanged on safetensors files.
+//!
+//! ```text
+//! Safetensors layout:
+//!   [8 bytes]            u64 LE = header_size
+//!   [header_size bytes]  JSON = tensor metadata
+//!   [remaining bytes]    raw tensor data (contiguous, no padding)
+//!
+//! JSON structure:
+//!   { "__metadata__": {...},
+//!     "tensor_name": { "dtype": "BF16", "shape": [d0, d1], "data_offsets": [start, end] },
+//!     ... }
+//!
+//! data_offsets are relative to the start of the data section (byte after JSON header).
+//! ```
+//!
+//! The key advantage over GGUF for the reasoning diff pipeline:
+//! safetensors stores full BF16 precision, while GGUF Q8_0 introduces
+//! quantization noise. BF16→BF16 diff gives cleaner causal attribution.
+
+use super::gguf::{GgufFile, TensorInfo, GgmlType};
+use std::collections::HashMap;
+use std::io::Read;
+
+// ============================================================================
+// Safetensors dtype → GgmlType mapping
+// ============================================================================
+
+fn parse_dtype(s: &str) -> Result<GgmlType, String> {
+    match s {
+        "BF16" | "bfloat16" => Ok(GgmlType::BF16),
+        "F16" | "float16" => Ok(GgmlType::F16),
+        "F32" | "float32" => Ok(GgmlType::F32),
+        // Safetensors also supports I8, I16, I32, I64, F64, BOOL, U8, etc.
+        // For weight indexing, we only care about float types.
+        other => Err(format!("unsupported safetensors dtype: {}", other)),
+    }
+}
+
+// ============================================================================
+// JSON parser — minimal, no serde dependency
+// ============================================================================
+
+/// Parse a safetensors JSON header without serde.
+///
+/// We only need: tensor names, dtypes, shapes, and data_offsets.
+/// The JSON is always a flat object of objects.
+/// Entries with unsupported dtypes are silently skipped; the result is
+/// sorted by data offset so the indexer can read the file sequentially.
+fn parse_safetensors_json(json: &str) -> Result<Vec<TensorInfo>, String> {
+    let mut tensors = Vec::new();
+
+    // Simple state-machine JSON parser for the safetensors format.
+    // The format is always: { "name": { "dtype": "...", "shape": [...], "data_offsets": [a, b] }, ... }
+    // We skip "__metadata__" entries.
+
+    let json = json.trim();
+    if !json.starts_with('{') || !json.ends_with('}') {
+        return Err("invalid JSON: not an object".into());
+    }
+
+    // Find all tensor entries by scanning for "dtype" keys
+    // This is a pragmatic parser — not a full JSON parser.
+    let mut pos = 1; // skip opening {
+    let bytes = json.as_bytes();
+    let len = bytes.len();
+
+    while pos < len - 1 {
+        // Skip whitespace and commas
+        while pos < len && (bytes[pos] == b' ' || bytes[pos] == b'\n' ||
+                            bytes[pos] == b'\r' || bytes[pos] == b'\t' ||
+                            bytes[pos] == b',') {
+            pos += 1;
+        }
+        if pos >= len - 1 { break; }
+        if bytes[pos] == b'}' { break; }
+
+        // Read key (tensor name)
+        // NOTE: escaped characters are skipped but not unescaped — a tensor
+        // name containing \" would keep its backslash. Fine for real models.
+        if bytes[pos] != b'"' {
+            pos += 1;
+            continue;
+        }
+        let key_start = pos + 1;
+        pos += 1;
+        while pos < len && bytes[pos] != b'"' {
+            if bytes[pos] == b'\\' { pos += 1; } // skip escaped char
+            pos += 1;
+        }
+        let key = &json[key_start..pos];
+        pos += 1; // skip closing "
+
+        // Skip colon
+        while pos < len && bytes[pos] != b':' { pos += 1; }
+        pos += 1; // skip :
+
+        // Skip whitespace
+        while pos < len && (bytes[pos] == b' ' || bytes[pos] == b'\n' ||
+                            bytes[pos] == b'\r' || bytes[pos] == b'\t') {
+            pos += 1;
+        }
+        // Guard: truncated input after a key — nothing left to parse.
+        if pos >= len { break; }
+
+        if key == "__metadata__" {
+            // Skip the metadata object — find matching closing brace,
+            // ignoring braces that appear inside string values.
+            if bytes[pos] == b'{' {
+                let mut depth = 1;
+                pos += 1;
+                while pos < len && depth > 0 {
+                    if bytes[pos] == b'{' { depth += 1; }
+                    if bytes[pos] == b'}' { depth -= 1; }
+                    if bytes[pos] == b'"' {
+                        pos += 1;
+                        while pos < len && bytes[pos] != b'"' {
+                            if bytes[pos] == b'\\' { pos += 1; }
+                            pos += 1;
+                        }
+                    }
+                    pos += 1;
+                }
+            }
+            continue;
+        }
+
+        // Parse tensor value object: { "dtype": "...", "shape": [...], "data_offsets": [...] }
+        if bytes[pos] != b'{' {
+            // Not an object — skip until next comma or closing brace
+            while pos < len && bytes[pos] != b',' && bytes[pos] != b'}' { pos += 1; }
+            continue;
+        }
+
+        // Find the closing brace for this tensor's object
+        let obj_start = pos;
+        let mut depth = 1;
+        pos += 1;
+        while pos < len && depth > 0 {
+            if bytes[pos] == b'{' { depth += 1; }
+            if bytes[pos] == b'}' { depth -= 1; }
+            if bytes[pos] == b'"' {
+                pos += 1;
+                while pos < len && bytes[pos] != b'"' {
+                    if bytes[pos] == b'\\' { pos += 1; }
+                    pos += 1;
+                }
+            }
+            pos += 1;
+        }
+        let obj_str = &json[obj_start..pos];
+
+        // Extract dtype
+        let dtype_str = extract_json_string(obj_str, "dtype").unwrap_or_default();
+        let dtype = match parse_dtype(&dtype_str) {
+            Ok(d) => d,
+            Err(_) => continue, // skip unsupported dtypes
+        };
+
+        // Extract shape
+        let shape = extract_json_array_u64(obj_str, "shape").unwrap_or_default();
+
+        // Extract data_offsets — we only need the start; the end is implied
+        // by dtype × shape and is not validated here.
+        let offsets = extract_json_array_u64(obj_str, "data_offsets").unwrap_or_default();
+        let offset = offsets.first().copied().unwrap_or(0);
+
+        tensors.push(TensorInfo {
+            name: key.to_string(),
+            dimensions: shape,
+            dtype,
+            offset,
+        });
+    }
+
+    // Sort by offset for sequential reading
+    tensors.sort_by_key(|t| t.offset);
+
+    Ok(tensors)
+}
+
+/// Extract a string value for a key from a JSON object fragment.
+fn extract_json_string(obj: &str, key: &str) -> Option<String> {
+    let pattern = format!("\"{}\"", key);
+    let pos = obj.find(&pattern)?;
+    let after_key = &obj[pos + pattern.len()..];
+
+    // Find colon then opening quote
+    let colon = after_key.find(':')?;
+    let rest = &after_key[colon + 1..];
+    let quote1 = rest.find('"')?;
+    let rest = &rest[quote1 + 1..];
+    let quote2 = rest.find('"')?;
+
+    Some(rest[..quote2].to_string())
+}
+
+/// Extract a u64 array value for a key from a JSON object fragment.
+fn extract_json_array_u64(obj: &str, key: &str) -> Option<Vec<u64>> {
+    let pattern = format!("\"{}\"", key);
+    let pos = obj.find(&pattern)?;
+    let after_key = &obj[pos + pattern.len()..];
+
+    let bracket_open = after_key.find('[')?;
+    // Search for ']' AFTER the '[' — searching the whole fragment could find
+    // a ']' that precedes the '[' and produce an inverted (panicking) range.
+    let bracket_close = bracket_open + after_key[bracket_open..].find(']')?;
+    let array_str = &after_key[bracket_open + 1..bracket_close];
+
+    let values: Vec<u64> = array_str.split(',')
+        .filter_map(|s| s.trim().parse().ok())
+        .collect();
+
+    Some(values)
+}
+
+// ============================================================================
+// Header reader
+// ============================================================================
+
+/// Read a safetensors file header and produce a GgufFile-compatible struct.
+///
+/// The returned `GgufFile` has:
+/// - `tensor_data_offset`: absolute byte offset where tensor data starts
+/// - `tensors`: Vec with offsets relative to data start
+/// - `version`: 0 (not a GGUF version)
+/// - `alignment`: 1 (safetensors has no alignment padding)
+///
+/// Reads sequentially from the current position; only `Read` is required.
+pub fn read_safetensors_header<R: Read>(reader: &mut R) -> Result<GgufFile, String> {
+    // Read header size (first 8 bytes, u64 LE)
+    let mut size_buf = [0u8; 8];
+    reader.read_exact(&mut size_buf).map_err(|e| format!("read header size: {}", e))?;
+    let header_size = u64::from_le_bytes(size_buf);
+
+    // Sanity cap: a corrupt/hostile size field would otherwise trigger a
+    // multi-GB allocation below.
+    if header_size > 100_000_000 {
+        return Err(format!("header_size {} too large (>100 MB)", header_size));
+    }
+
+    // Read JSON header
+    let mut json_buf = vec![0u8; header_size as usize];
+    reader.read_exact(&mut json_buf).map_err(|e| format!("read header JSON: {}", e))?;
+    let json_str = String::from_utf8(json_buf).map_err(|e| format!("header not UTF-8: {}", e))?;
+
+    // Parse tensors
+    let tensors = parse_safetensors_json(&json_str)?;
+
+    // Data starts immediately after the header
+    let tensor_data_offset = 8 + header_size;
+
+    // NOTE(review): stderr logging in a library routine — kept for parity
+    // with the GGUF reader's progress output.
+    eprintln!("  Safetensors: {} tensors, data at byte {}",
+              tensors.len(), tensor_data_offset);
+
+    Ok(GgufFile {
+        version: 0,
+        metadata: HashMap::new(),
+        tensors,
+        tensor_data_offset,
+        alignment: 1,
+    })
+}
+
+// ============================================================================
+// Streaming indexer entry point
+// ============================================================================
+
+/// Stream-index a safetensors file through the BF16-direct pipeline.
+///
+/// This is a thin wrapper: parse safetensors header → produce GgufFile →
+/// delegate to `stream_index_gguf_bf16_with_header`.
+///
+/// Why this matters: safetensors stores full BF16 weights (no quantization).
+/// The GGUF Q8_0 path introduces 8-bit quantization noise before projection.
+/// BF16→Base17 gives cleaner fingerprints for causal diffing.
+pub fn stream_index_safetensors_bf16<R: Read + std::io::Seek, W: std::io::Write>(
+    reader: &mut R,
+    writer: &mut W,
+    octave_stride: usize,
+    callback: Option<&dyn Fn(&str, &super::gguf_indexer::LayerType, usize, usize)>,
+) -> Result<super::gguf_indexer::IndexStats, String> {
+    // Parse safetensors header (produces GgufFile-compatible struct)
+    let header = read_safetensors_header(reader)?;
+
+    // Delegate to the existing BF16-direct chunked indexer.
+    // The indexer uses: header.tensors, header.tensor_data_offset, tensor.offset, tensor.dtype
+    // All of these are populated by read_safetensors_header identically to read_gguf_header.
+    super::gguf_indexer::stream_index_gguf_bf16_with_header(
+        reader, writer, &header, octave_stride, callback,
+    )
+}
+
+// ============================================================================
+// Tests
+// ============================================================================
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_dtype() {
+        assert_eq!(parse_dtype("BF16").unwrap(), GgmlType::BF16);
+        assert_eq!(parse_dtype("bfloat16").unwrap(), GgmlType::BF16);
+        assert_eq!(parse_dtype("F16").unwrap(), GgmlType::F16);
+        assert_eq!(parse_dtype("F32").unwrap(), GgmlType::F32);
+        assert!(parse_dtype("I32").is_err());
+    }
+
+    #[test]
+    fn test_parse_safetensors_json_minimal() {
+        let json = r#"{
+            "__metadata__": {"format": "pt"},
+            "model.embed_tokens.weight": {
+                "dtype": "BF16",
+                "shape": [151936, 3584],
+                "data_offsets": [0, 1089470464]
+            },
+            "model.layers.0.self_attn.q_proj.weight": {
+                "dtype": "BF16",
+                "shape": [3584, 3584],
+                "data_offsets": [1089470464, 1115095040]
+            }
+        }"#;
+
+        let tensors = parse_safetensors_json(json).unwrap();
+        assert_eq!(tensors.len(), 2);
+
+        // Sorted by offset
+        assert_eq!(tensors[0].name, "model.embed_tokens.weight");
+        assert_eq!(tensors[0].dimensions, vec![151936, 3584]);
+        assert_eq!(tensors[0].dtype, GgmlType::BF16);
+        assert_eq!(tensors[0].offset, 0);
+
+        assert_eq!(tensors[1].name, "model.layers.0.self_attn.q_proj.weight");
+        assert_eq!(tensors[1].offset, 1089470464);
+    }
+
+    #[test]
+    fn test_extract_json_helpers() {
+        let obj = r#"{"dtype": "BF16", "shape": [3584, 3584], "data_offsets": [100, 200]}"#;
+        assert_eq!(extract_json_string(obj, "dtype"), Some("BF16".into()));
+        assert_eq!(extract_json_array_u64(obj, "shape"), Some(vec![3584, 3584]));
+        assert_eq!(extract_json_array_u64(obj, "data_offsets"), Some(vec![100, 200]));
+    }
+
+    #[test]
+    fn test_read_synthetic_safetensors() {
+        use std::io::Cursor;
+
+        // Build a minimal safetensors file in memory
+        let json = r#"{"tensor_a": {"dtype": "BF16", "shape": [4, 8], "data_offsets": [0, 64]}}"#;
+        let json_bytes = json.as_bytes();
+        let header_size = json_bytes.len() as u64;
+
+        let mut file_bytes = Vec::new();
+        file_bytes.extend_from_slice(&header_size.to_le_bytes());
+        file_bytes.extend_from_slice(json_bytes);
+        // 64 bytes of BF16 data (4 rows × 8 cols × 2 bytes).
+        // BF16(1.0) = 0x3F80, stored little-endian as the byte pair [0x80, 0x3F].
+        // (`vec![0x3F, 0x80; 32]` is not valid Rust — vec! takes a list OR
+        // `elem; n`, never both — so build the buffer with slice::repeat.)
+        file_bytes.extend_from_slice(&[0x80u8, 0x3F].repeat(32));
+
+        let mut cursor = Cursor::new(file_bytes);
+        let header = read_safetensors_header(&mut cursor).unwrap();
+
+        assert_eq!(header.tensors.len(), 1);
+        assert_eq!(header.tensors[0].name, "tensor_a");
+        assert_eq!(header.tensors[0].dimensions, vec![4, 8]);
+        assert_eq!(header.tensors[0].dtype, GgmlType::BF16);
+        assert_eq!(header.tensor_data_offset, 8 + header_size);
+    }
+
+    #[test]
+    #[ignore] // Streams ~55 GB from HuggingFace
+    fn test_stream_index_qwen35_safetensors() {
+        use super::super::http_reader::HttpRangeReader;
+        use std::io::BufWriter;
+
+        let repo = "Qwen/Qwen3.5-27B";
+        let shards = 11;
+
+        for shard in 1..=shards {
+            // NOTE(review): HF sharded repos conventionally name shards
+            // `model-00001-of-00011.safetensors` — confirm this pattern
+            // against the actual repo before running.
+            let filename = format!("model.safetensors-{:05}-of-{:05}.safetensors", shard, shards);
+            let out_path = format!("/tmp/qwen35_27b_base_shard{:02}.bgz7", shard);
+
+            if std::fs::metadata(&out_path).is_ok() {
+                eprintln!("SKIP {} (exists)", out_path);
+                continue;
+            }
+
+            let url = format!("https://huggingface.co/{}/resolve/main/{}", repo, filename);
+            eprintln!("Indexing shard {}/{}: {}", shard, shards, filename);
+
+            // HEAD for size — falls back to a 6 GB guess if curl/headers fail.
+            let size_str = std::process::Command::new("curl")
+                .args(&["-sI", "-L", &url])
+                .output()
+                .map(|o| String::from_utf8_lossy(&o.stdout).to_string())
+                .unwrap_or_default();
+            let size: u64 = size_str.lines()
+                .find(|l| l.to_lowercase().starts_with("content-length:"))
+                .and_then(|l| l.split(':').nth(1))
+                .and_then(|s| s.trim().parse().ok())
+                .unwrap_or(6_000_000_000);
+
+            let mut reader = HttpRangeReader::with_chunk_size(url, size, 256 * 1024 * 1024);
+            let out = std::fs::File::create(&out_path).expect("create output");
+            let mut writer = BufWriter::new(out);
+
+            let stats = stream_index_safetensors_bf16(
+                &mut reader, &mut writer, 16,
+                Some(&|name, _lt, orig, comp| {
+                    let ratio = if comp > 0 { orig as f64 / comp as f64 } else { 0.0 };
+                    eprintln!("  {:50} {:>12} → {:>8} ({:.0}×)", name, orig, comp, ratio);
+                }),
+            ).expect("safetensors indexing failed");
+
+            drop(writer);
+            eprintln!("  → {:.2} MB, {} tensors",
+                      std::fs::metadata(&out_path).map(|m| m.len()).unwrap_or(0) as f64 / 1e6,
+                      stats.tensors_indexed);
+        }
+    }
+}