Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 12 additions & 121 deletions src/hpc/gguf_indexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,25 +149,6 @@ pub fn project_row_to_base17(row: &[f32]) -> Base17 {
// BF16-direct optimizations: skip f32 intermediate, strided octave sampling
// ============================================================================

/// Halftone-dropped golden positions: keep every other step (9 of 17).
/// Well-distributed across 0..16; max gap = 3. Odd bins interpolated.
/// Halftone-dropped golden positions: keep every other step (9 of 17).
/// Well-distributed across 0..16; max gap = 3. Odd bins interpolated.
const HALFTONE_POS: [u8; 9] = {
    let mut table = [0u8; 9];
    // Walk the output slots directly: slot `s` corresponds to the even
    // golden-sequence index `2 * s` (0, 2, …, 16 — nine of the 17 steps).
    let mut slot = 0;
    while slot < (BASE_DIM + 1) / 2 {
        let step_index = slot * 2;
        table[slot] = ((step_index * GOLDEN_STEP) % BASE_DIM) as u8;
        slot += 1;
    }
    table
};

/// Which of the 17 Base17 bins each halftone position maps to.
/// Entry `hi` is `2 * hi`: `HALFTONE_POS[hi]` is built from even
/// golden-sequence index `2 * hi`, so its sample lands in the even bin of
/// the same index. The odd bins in between are filled by interpolation in
/// the strided projection path.
const HALFTONE_TO_BIN: [u8; 9] = [0, 2, 4, 6, 8, 10, 12, 14, 16];

/// Convert one BF16 u16 to f64. Zero allocation.
#[inline(always)]
fn bf16_to_f64(bits: u16) -> f64 {
Expand Down Expand Up @@ -205,52 +186,6 @@ pub fn project_row_bf16_direct(row: &[u16]) -> Base17 {
Base17 { dims }
}

/// Project a BF16 row with octave stride and halftone dropping.
///
/// For a 5120-element row at stride=16:
///   302 octaves / 16 = 19 sampled × 9 halftone = 171 BF16→f64 conversions
///   vs 5120 in the full path (97% reduction).
/// Odd bins interpolated from neighbors.
///
/// An `octave_stride` of 0 is treated as 1 (dense octave sampling) rather
/// than looping forever.
pub fn project_row_bf16_strided(row: &[u16], octave_stride: usize) -> Base17 {
    // Guard: a stride of 0 would never advance `octave` below and the
    // while-loop would spin forever. Clamp to 1 instead.
    let stride = octave_stride.max(1);

    let d = row.len();
    // Ceiling division: number of 17-wide octaves covering the row.
    let n_octaves = (d + BASE_DIM - 1) / BASE_DIM;

    // Running sum and sample count per halftone position, accumulated
    // across every sampled octave.
    let mut half_sum = [0.0f64; 9];
    let mut half_count = [0u32; 9];

    let mut octave = 0;
    while octave < n_octaves {
        for hi in 0..9 {
            let dim = octave * BASE_DIM + HALFTONE_POS[hi] as usize;
            // The last octave may be partial — skip out-of-range dims.
            if dim < d {
                half_sum[hi] += bf16_to_f64(row[dim]);
                half_count[hi] += 1;
            }
        }
        octave += stride;
    }

    let mut dims = [0i16; BASE_DIM];

    // Even bins: direct mean of halftone samples, fixed-point scaled and
    // saturated to the i16 range. Unsampled bins stay 0.
    for hi in 0..9 {
        let bin = HALFTONE_TO_BIN[hi] as usize;
        if half_count[hi] > 0 {
            let mean = half_sum[hi] / half_count[hi] as f64;
            dims[bin] = (mean * FP_SCALE).round().clamp(-32768.0, 32767.0) as i16;
        }
    }

    // Odd bins: average of the two even neighbors (circular: bin 15 uses
    // bin 16 and would wrap via `% BASE_DIM` at the top end).
    for odd in (1..BASE_DIM).step_by(2) {
        let left = dims[odd - 1] as i32;
        let right = dims[(odd + 1) % BASE_DIM] as i32;
        dims[odd] = ((left + right) / 2) as i16;
    }

    Base17 { dims }
}

// ── F64x8 SIMD: 8 rows → 8 Base17 in parallel ──

/// Gather 8 BF16 values from 8 rows at the same column, convert to F64x8.
Expand Down Expand Up @@ -364,11 +299,11 @@ pub fn project_1row_bf16_strided(row: &[u16], octave_stride: usize) -> Base17 {

/// Project an entire BF16 tensor to Base17 using F64x8 SIMD.
///
/// Processes 8 rows in parallel per SIMD batch. Each of the 9 halftone bins
/// holds an F64x8 accumulator (8 rows × 9 bins = 72 f64 lanes = 9 zmm registers).
/// Processes 8 rows in parallel per SIMD batch. Each of the 17 bins
/// holds an F64x8 accumulator (8 rows × 17 bins = 136 f64 lanes = 17 zmm registers).
///
/// Per sampled octave: 9 halftone positions × 8 bf16_to_f64 gathers → 9 vaddpd.
/// For 5120-col rows at stride=16: 19 octaves × 9 = 171 vaddpd per 8-row batch.
/// Per sampled octave: 17 positions × 8 bf16_to_f64 gathers → 17 vaddpd.
/// For 5120-col rows at stride=16: 19 octaves × 17 = 323 vaddpd per 8-row batch.
pub fn project_tensor_bf16_simd(
buf: &[u16],
n_rows: usize,
Expand Down Expand Up @@ -401,31 +336,6 @@ pub fn project_tensor_bf16_simd(
result
}

/// Read a BF16 tensor as raw u16 values. NO f32 conversion.
/// `buf` is reusable — caller allocates once, passes to every tensor.
///
/// Seeks `reader` to the tensor's absolute byte offset, reads the raw
/// payload straight into `buf` (growing it only if too small), and returns
/// the number of u16 elements read. I/O failures are returned as their
/// error message in `Err`.
pub fn read_tensor_bf16_raw<R: Read + Seek>(
    reader: &mut R,
    gguf_file: &gguf::GgufFile,
    tensor: &gguf::TensorInfo,
    buf: &mut Vec<u16>,
) -> Result<usize, String> {
    // Tensor offsets are relative to the file's tensor-data section.
    let abs_offset = gguf_file.tensor_data_offset + tensor.offset;
    reader.seek(std::io::SeekFrom::Start(abs_offset)).map_err(|e| e.to_string())?;

    let n_elements = tensor.element_count() as usize;
    // Grow-only: a larger leftover buffer from a previous tensor is fine,
    // since we read exactly n_elements * 2 bytes below.
    if buf.len() < n_elements {
        buf.resize(n_elements, 0);
    }

    // SAFETY: u16 and [u8; 2] have the same layout on little-endian (x86/ARM).
    // The pointer and length come from a live Vec<u16> of at least
    // n_elements, so the byte view covers initialized, exclusively-borrowed
    // memory. NOTE(review): reading the file's bytes as native u16 is only
    // value-correct on little-endian hosts — a byte swap would be needed on
    // big-endian targets; confirm whether those are supported.
    let byte_slice = unsafe {
        std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u8, n_elements * 2)
    };
    reader.read_exact(byte_slice).map_err(|e| e.to_string())?;

    Ok(n_elements)
}

/// Helper: tensor dimensions → (rows, cols) without needing data.
fn tensor_to_rows_dims(dims: &[u64], layer_type: &LayerType) -> (usize, usize) {
match layer_type {
Expand Down Expand Up @@ -1206,14 +1116,6 @@ mod tests {
}

/// Exact Scout BF16 shard sizes (verified via HuggingFace HEAD).
/// Indexed by shard number minus one; consumed when constructing the HTTP
/// range reader for each shard. NOTE(review): these byte counts are
/// hard-coded and must be re-verified if the upstream repo re-uploads the
/// shards.
const SCOUT_SHARD_SIZES: [u64; 5] = [
    48_940_000_000, // shard 1: layers 0-10 + embeddings
    49_960_000_000, // shard 2: layers 11-21
    48_660_000_000, // shard 3: layers 22-32
    49_790_000_000, // shard 4: layers 33-43
    18_220_000_000, // shard 5: layers 44-47 + output
];

/// Run one shard of Llama 4 Scout BF16 through the BF16-direct indexer.
///
/// Uses stream_index_gguf_bf16 with F64x8 SIMD and strided octave sampling.
Expand All @@ -1226,14 +1128,13 @@ mod tests {
let filename = format!(
"BF16/Llama-4-Scout-17B-16E-Instruct-BF16-{:05}-of-00005.gguf", shard
);
let size = SCOUT_SHARD_SIZES[(shard - 1) as usize];
let octave_stride: usize = 16; // 4 octaves higher + halftone drop
let octave_stride: usize = 16;

let url = format!("https://huggingface.co/{}/resolve/main/{}", repo, filename);
eprintln!("Streaming shard {}/5: {} ({:.2} GB)", shard, filename, size as f64 / 1e9);
eprintln!("Streaming shard {}/5: {}", shard, filename);
eprintln!(" BF16-direct, octave_stride={}, F64x8 SIMD", octave_stride);

let mut reader = HttpRangeReader::with_chunk_size(url, size, 256 * 1024 * 1024);
let mut reader = HttpRangeReader::from_hf(repo, &filename, 256 * 1024 * 1024)
.expect("failed to resolve HF URL — check repo/filename");

let out_path = format!("/tmp/llama4_scout_shard{}.bgz7", shard);
let out = std::fs::File::create(&out_path).expect("create output");
Expand Down Expand Up @@ -1296,14 +1197,6 @@ mod tests {

// ── BF16-direct optimization tests ──

/// The nine halftone positions must cover exactly this fixed set of
/// offsets within 0..17 (order-independent check).
#[test]
fn test_halftone_positions_coverage() {
    let mut observed: Vec<u8> = HALFTONE_POS.iter().copied().collect();
    observed.sort_unstable();
    assert_eq!(observed, [0, 1, 3, 5, 6, 8, 10, 13, 15]);
}

#[test]
fn test_bf16_to_f64_accuracy() {
assert_eq!(bf16_to_f64(0x3F80), 1.0);
Expand All @@ -1318,7 +1211,7 @@ mod tests {
// Constant BF16 row → stride shouldn't matter
let row: Vec<u16> = vec![0x3F80; 5120]; // all 1.0
let full = project_row_bf16_direct(&row);
let strided = project_row_bf16_strided(&row, 16);
let strided = project_1row_bf16_strided(&row, 16);

for i in 0..17 {
let diff = (full.dims[i] as i32 - strided.dims[i] as i32).abs();
Expand Down Expand Up @@ -1443,14 +1336,12 @@ mod tests {
eprintln!();

for (shard_num, filename, size) in shards.iter() {
let url = format!("https://huggingface.co/{}/resolve/main/{}", repo, filename);
let out_path = format!("/tmp/llama4_maverick_shard{:02}.bgz7", shard_num);

eprintln!("━━━ Shard {:02}/18 ({:.2} GB) ━━━", shard_num, *size as f64 / 1e9);
eprintln!("━━━ Shard {:02}/18 ━━━", shard_num);

let mut reader = HttpRangeReader::with_chunk_size(
url.clone(), *size, 256 * 1024 * 1024
);
let mut reader = HttpRangeReader::from_hf(repo, filename, 256 * 1024 * 1024)
.expect("failed to resolve HF URL");

let out = std::fs::File::create(&out_path).expect("create output");
let mut writer = BufWriter::new(out);
Expand Down
Loading