Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 67 additions & 21 deletions src/hpc/gguf_indexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,7 @@ pub fn project_8rows_bf16_simd(

for bin in 0..BASE_DIM {
let c = counts[bin].max(1) as f64;
let scaled = sums[bin].mul_add(
F64x8::splat(FP_SCALE / c),
F64x8::splat(0.0),
);
let scaled = sums[bin] * F64x8::splat(FP_SCALE / c);
let clamped = scaled.round().simd_clamp(lo, hi);
let vals = clamped.to_array();
for lane in 0..8 {
Expand Down Expand Up @@ -529,7 +526,9 @@ pub fn stream_index_gguf_bf16<R: Read + Seek, W: Write>(
writer.write_all(b"BGZ7").map_err(|e| e.to_string())?;
writer.write_all(&(gguf_header.tensors.len() as u32).to_le_bytes()).map_err(|e| e.to_string())?;

// ONE reusable buffer — grows to largest tensor, never shrinks
// Reusable buffer — capped at 128 MB (64M u16 elements).
// Tensors larger than this are read in row batches.
const MAX_BUF_ELEMS: usize = 64 * 1024 * 1024; // 128 MB of u16
let mut bf16_buf: Vec<u16> = Vec::new();

for tensor in &gguf_header.tensors {
Expand All @@ -543,23 +542,64 @@ pub fn stream_index_gguf_bf16<R: Read + Seek, W: Write>(
let is_bf16 = matches!(tensor.dtype, gguf::GgmlType::BF16);

if is_bf16 {
// FAST PATH: BF16 direct — no f32 intermediate
let n_elements = read_tensor_bf16_raw(reader, &gguf_header, tensor, &mut bf16_buf)?;
// FAST PATH: BF16 direct — chunked row-batch reading.
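// Caps the buffer at roughly MAX_BUF_ELEMS regardless of tensor size; a batch is
// never smaller than 8 rows (or the whole tensor), so very wide rows can exceed the cap.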
// A 10.7 GB ffn_gate_exps tensor reads in ~128 MB batches.
let (n_rows, n_cols) = tensor_to_rows_dims(&tensor.dimensions, &layer_type);

// F64x8: 8 rows parallel, SIMD accumulation per halftone bin
let rows = if octave_stride > 1 {
project_tensor_bf16_simd(&bf16_buf[..n_elements], n_rows, n_cols, octave_stride)
let chunk_rows = if n_cols > 0 {
(MAX_BUF_ELEMS / n_cols).max(8).min(n_rows) // at least 8 rows (SIMD batch)
} else {
// Full precision: scalar per-row (stride=1 doesn't benefit from SIMD halftone)
let mut rows = Vec::with_capacity(n_rows);
for r in 0..n_rows {
let start = r * n_cols;
let end = (start + n_cols).min(n_elements);
rows.push(project_row_bf16_direct(&bf16_buf[start..end]));
}
rows
n_rows
};
let chunk_elems = chunk_rows * n_cols;

// Grow buffer to chunk size (not full tensor size)
if bf16_buf.len() < chunk_elems {
bf16_buf.resize(chunk_elems, 0);
}

// Seek to tensor start
let abs_offset = gguf_header.tensor_data_offset + tensor.offset;
reader.seek(std::io::SeekFrom::Start(abs_offset)).map_err(|e| e.to_string())?;

let mut rows: Vec<Base17> = Vec::with_capacity(n_rows);
let mut rows_done: usize = 0;
let is_large = n_rows > chunk_rows;

while rows_done < n_rows {
let batch_n = (n_rows - rows_done).min(chunk_rows);
let batch_elems = batch_n * n_cols;

// Read batch bytes into reusable buffer
let byte_slice = unsafe {
std::slice::from_raw_parts_mut(
bf16_buf.as_mut_ptr() as *mut u8,
batch_elems * 2,
)
};
reader.read_exact(byte_slice).map_err(|e| e.to_string())?;

// Project this batch
if octave_stride > 1 {
let batch_b17 = project_tensor_bf16_simd(
&bf16_buf[..batch_elems], batch_n, n_cols, octave_stride
);
rows.extend_from_slice(&batch_b17);
} else {
for r in 0..batch_n {
let start = r * n_cols;
rows.push(project_row_bf16_direct(&bf16_buf[start..start + n_cols]));
}
}

rows_done += batch_n;

// Progress for large tensors (every chunk)
if is_large && rows_done < n_rows {
eprintln!(" ... {}/{} rows ({:.0}%)",
rows_done, n_rows, rows_done as f64 / n_rows as f64 * 100.0);
}
}

let orig_bytes = (n_rows * n_cols * 4) as u64;
let comp_bytes = (rows.len() * Base17::BYTE_SIZE) as u64;
Expand All @@ -582,8 +622,14 @@ pub fn stream_index_gguf_bf16<R: Read + Seek, W: Write>(
stats.compressed_bytes += comp_bytes;
stats.tensors_indexed += 1;

let peak = n_elements as u64 * 2;
if peak > stats.peak_tensor_bytes { stats.peak_tensor_bytes = peak; }
let buf_bytes = chunk_elems as u64 * 2;
if buf_bytes > stats.peak_tensor_bytes { stats.peak_tensor_bytes = buf_bytes; }

// Shrink buffer if it grew past the cap (possible when the 8-row minimum batch
// of a very wide tensor exceeds MAX_BUF_ELEMS)
if bf16_buf.len() > MAX_BUF_ELEMS {
bf16_buf.truncate(MAX_BUF_ELEMS);
bf16_buf.shrink_to(MAX_BUF_ELEMS);
}

if let Some(cb) = callback {
cb(&tensor.name, &layer_type, orig_bytes as usize, comp_bytes as usize);
Expand Down
Loading