From 2213ce95e648a1f686f44054ed0ea10dcfda7a92 Mon Sep 17 00:00:00 2001
From: AdaWorldAPI
Date: Mon, 30 Mar 2026 09:42:03 +0200
Subject: [PATCH] fix: chunked BF16 reading, buffer cap, drop fake FMA
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. mul_add with a zero addend → plain multiply (was wasting an FMA slot).

2. Chunked row-batch reading for BF16 tensors: the read buffer is capped at
   128 MB regardless of tensor size. A 10.7 GB ffn_gate_exps tensor now reads
   in batches of ~4.8K rows (~128 MB each) instead of one 10.7 GB allocation.
   Minimum batch = 8 rows (one F64x8 SIMD width).

3. Buffer shrink_to after oversized tensors: bf16_buf is truncated back to
   MAX_BUF_ELEMS (64M u16 = 128 MB) if it somehow grew past the cap.

4. Progress logging within large tensors: the row count is printed after every
   chunk, so there is visible activity during multi-minute tensor reads.

read_tensor_bf16_raw() is now unused in the main path (kept for potential
direct use in tests or smaller models).
---
 src/hpc/gguf_indexer.rs | 88 +++++++++++++++++++++++++++++++----------
 1 file changed, 67 insertions(+), 21 deletions(-)

diff --git a/src/hpc/gguf_indexer.rs b/src/hpc/gguf_indexer.rs
index a36fc19f..dec15ec5 100644
--- a/src/hpc/gguf_indexer.rs
+++ b/src/hpc/gguf_indexer.rs
@@ -345,10 +345,7 @@ pub fn project_8rows_bf16_simd(
     for bin in 0..BASE_DIM {
         let c = counts[bin].max(1) as f64;

-        let scaled = sums[bin].mul_add(
-            F64x8::splat(FP_SCALE / c),
-            F64x8::splat(0.0),
-        );
+        let scaled = sums[bin] * F64x8::splat(FP_SCALE / c);
         let clamped = scaled.round().simd_clamp(lo, hi);
         let vals = clamped.to_array();
         for lane in 0..8 {
@@ -529,7 +526,9 @@ pub fn stream_index_gguf_bf16(
     writer.write_all(b"BGZ7").map_err(|e| e.to_string())?;
     writer.write_all(&(gguf_header.tensors.len() as u32).to_le_bytes()).map_err(|e| e.to_string())?;

-    // ONE reusable buffer — grows to largest tensor, never shrinks
+    // Reusable buffer — capped at 128 MB (64M u16 elements).
+    // Tensors larger than this are read in row batches.
+    const MAX_BUF_ELEMS: usize = 64 * 1024 * 1024; // 128 MB of u16
     let mut bf16_buf: Vec<u16> = Vec::new();

     for tensor in &gguf_header.tensors {
@@ -543,23 +542,64 @@ pub fn stream_index_gguf_bf16(
         let is_bf16 = matches!(tensor.dtype, gguf::GgmlType::BF16);

         if is_bf16 {
-            // FAST PATH: BF16 direct — no f32 intermediate
-            let n_elements = read_tensor_bf16_raw(reader, &gguf_header, tensor, &mut bf16_buf)?;
+            // FAST PATH: BF16 direct — chunked row-batch reading.
+            // Caps memory at MAX_BUF_ELEMS regardless of tensor size.
+            // A 10.7 GB ffn_gate_exps tensor reads in ~128 MB batches.
             let (n_rows, n_cols) = tensor_to_rows_dims(&tensor.dimensions, &layer_type);
-
-            // F64x8: 8 rows parallel, SIMD accumulation per halftone bin
-            let rows = if octave_stride > 1 {
-                project_tensor_bf16_simd(&bf16_buf[..n_elements], n_rows, n_cols, octave_stride)
+            let chunk_rows = if n_cols > 0 {
+                (MAX_BUF_ELEMS / n_cols).max(8).min(n_rows) // at least 8 rows (SIMD batch)
             } else {
-                // Full precision: scalar per-row (stride=1 doesn't benefit from SIMD halftone)
-                let mut rows = Vec::with_capacity(n_rows);
-                for r in 0..n_rows {
-                    let start = r * n_cols;
-                    let end = (start + n_cols).min(n_elements);
-                    rows.push(project_row_bf16_direct(&bf16_buf[start..end]));
-                }
-                rows
+                n_rows
             };
+            let chunk_elems = chunk_rows * n_cols;
+
+            // Grow buffer to chunk size (not full tensor size)
+            if bf16_buf.len() < chunk_elems {
+                bf16_buf.resize(chunk_elems, 0);
+            }
+
+            // Seek to tensor start
+            let abs_offset = gguf_header.tensor_data_offset + tensor.offset;
+            reader.seek(std::io::SeekFrom::Start(abs_offset)).map_err(|e| e.to_string())?;
+
+            let mut rows: Vec<Base17> = Vec::with_capacity(n_rows);
+            let mut rows_done: usize = 0;
+            let is_large = n_rows > chunk_rows;
+
+            while rows_done < n_rows {
+                let batch_n = (n_rows - rows_done).min(chunk_rows);
+                let batch_elems = batch_n * n_cols;
+
+                // Read batch bytes into reusable buffer
+                let byte_slice = unsafe {
+                    std::slice::from_raw_parts_mut(
+                        bf16_buf.as_mut_ptr() as *mut u8,
+                        batch_elems * 2,
+                    )
+                };
+                reader.read_exact(byte_slice).map_err(|e| e.to_string())?;
+
+                // Project this batch
+                if octave_stride > 1 {
+                    let batch_b17 = project_tensor_bf16_simd(
+                        &bf16_buf[..batch_elems], batch_n, n_cols, octave_stride
+                    );
+                    rows.extend_from_slice(&batch_b17);
+                } else {
+                    for r in 0..batch_n {
+                        let start = r * n_cols;
+                        rows.push(project_row_bf16_direct(&bf16_buf[start..start + n_cols]));
+                    }
+                }
+
+                rows_done += batch_n;
+
+                // Progress for large tensors (every chunk)
+                if is_large && rows_done < n_rows {
+                    eprintln!("  ... {}/{} rows ({:.0}%)",
+                        rows_done, n_rows, rows_done as f64 / n_rows as f64 * 100.0);
+                }
+            }

             let orig_bytes = (n_rows * n_cols * 4) as u64;
             let comp_bytes = (rows.len() * Base17::BYTE_SIZE) as u64;
@@ -582,8 +622,14 @@
             stats.compressed_bytes += comp_bytes;
             stats.tensors_indexed += 1;

-            let peak = n_elements as u64 * 2;
-            if peak > stats.peak_tensor_bytes { stats.peak_tensor_bytes = peak; }
+            let buf_bytes = chunk_elems as u64 * 2;
+            if buf_bytes > stats.peak_tensor_bytes { stats.peak_tensor_bytes = buf_bytes; }
+
+            // Shrink buffer if it grew past the cap (shouldn't, but defensive)
+            if bf16_buf.len() > MAX_BUF_ELEMS {
+                bf16_buf.truncate(MAX_BUF_ELEMS);
+                bf16_buf.shrink_to(MAX_BUF_ELEMS);
+            }

             if let Some(cb) = callback {
                 cb(&tensor.name, &layer_type, orig_bytes as usize, comp_bytes as usize);
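
Note (not part of the patch): below is a minimal standalone sketch of the
batch-size arithmetic described in item 2 of the commit message, for readers
who want to sanity-check the numbers. The tensor shape used here is
hypothetical (chosen only to land near the 10.7 GB ballpark); real shapes come
from tensor_to_rows_dims().

    /// Sketch of the row-batch sizing under the 128 MB cap.
    /// Assumes BF16 storage (2 bytes per element); the dimensions are
    /// illustrative, not read from any actual GGUF file.
    const MAX_BUF_ELEMS: usize = 64 * 1024 * 1024; // 64M u16 = 128 MB

    fn chunk_rows(n_rows: usize, n_cols: usize) -> usize {
        if n_cols == 0 {
            return n_rows;
        }
        // At least 8 rows (one F64x8 SIMD batch), at most the whole tensor.
        (MAX_BUF_ELEMS / n_cols).max(8).min(n_rows)
    }

    fn main() {
        // Hypothetical expert tensor: ~5.6G BF16 elements, roughly 11 GB.
        let n_rows: usize = 393_216;
        let n_cols: usize = 14_336;

        let per_batch = chunk_rows(n_rows, n_cols);           // ~4.7K rows
        let n_batches = (n_rows + per_batch - 1) / per_batch;  // ~85 batches
        let batch_mb = per_batch * n_cols * 2 / (1024 * 1024);

        println!("{} rows/batch (~{} MB), {} batches", per_batch, batch_mb, n_batches);
    }

With these made-up dimensions the cap works out to 4,681 rows per batch (just
under 128 MB), so the tensor is read through roughly 85 bounded fills of the
reusable buffer instead of one multi-gigabyte allocation.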