From aef5af13088797eb83ebb9f80618495bb2457d04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Tue, 12 May 2026 12:30:21 +0200 Subject: [PATCH] Introduce quad SIMD search --- roaring/src/bitmap/store/array_store/mod.rs | 80 +++++++++--- .../src/bitmap/store/array_store/vector.rs | 115 ++++++++++++++++++ 2 files changed, 176 insertions(+), 19 deletions(-) diff --git a/roaring/src/bitmap/store/array_store/mod.rs b/roaring/src/bitmap/store/array_store/mod.rs index 61967704..c8f3acf1 100644 --- a/roaring/src/bitmap/store/array_store/mod.rs +++ b/roaring/src/bitmap/store/array_store/mod.rs @@ -5,6 +5,7 @@ mod visitor; use crate::bitmap::store::array_store::visitor::{CardinalityCounter, VecWriter}; use core::cmp::Ordering; use core::cmp::Ordering::*; +use core::convert::identity; use core::fmt::{Display, Formatter}; use core::mem::size_of; use core::ops::{BitAnd, BitAndAssign, BitOr, BitXor, RangeInclusive, Sub, SubAssign}; @@ -126,7 +127,11 @@ impl ArrayStore { #[inline] pub fn insert(&mut self, index: u16) -> bool { - self.vec.binary_search(&index).map_err(|loc| self.vec.insert(loc, index)).is_err() + #[cfg(feature = "simd")] + let result = vector::quad_search(&self.vec, index); + #[cfg(not(feature = "simd"))] + let result = self.vec.binary_search(&index); + result.map_err(|loc| self.vec.insert(loc, index)).is_err() } pub fn insert_range(&mut self, range: RangeInclusive) -> u64 { @@ -134,12 +139,20 @@ impl ArrayStore { let end = *range.end(); // Figure out the starting/ending position in the vec. - let pos_start = self.vec.binary_search(&start).unwrap_or_else(|x| x); - let pos_end = pos_start - + match self.vec[pos_start..].binary_search(&end) { - Ok(x) => x + 1, - Err(x) => x, - }; + #[cfg(feature = "simd")] + let pos_start = vector::quad_search(&self.vec, start).unwrap_or_else(identity); + #[cfg(not(feature = "simd"))] + let pos_start = self.vec.binary_search(&start).unwrap_or_else(identity); + + #[cfg(feature = "simd")] + let pos_end_result = vector::quad_search(&self.vec[pos_start..], end); + #[cfg(not(feature = "simd"))] + let pos_end_result = self.vec[pos_start..].binary_search(&end); + + let pos_end = match pos_end_result { + Ok(x) => x + pos_start + 1, + Err(x) => x + pos_start, + }; // Overwrite the range in the middle - there's no need to take // into account any existing elements between start and end, as @@ -175,7 +188,12 @@ impl ArrayStore { } pub fn remove(&mut self, index: u16) -> bool { - self.vec.binary_search(&index).map(|loc| self.vec.remove(loc)).is_ok() + #[cfg(feature = "simd")] + let result = vector::quad_search(&self.vec, index); + #[cfg(not(feature = "simd"))] + let result = self.vec.binary_search(&index); + + result.map(|loc| self.vec.remove(loc)).is_ok() } pub fn remove_range(&mut self, range: RangeInclusive) -> u64 { @@ -183,12 +201,21 @@ impl ArrayStore { let end = *range.end(); // Figure out the starting/ending position in the vec. - let pos_start = self.vec.binary_search(&start).unwrap_or_else(|x| x); - let pos_end = pos_start - + match self.vec[pos_start..].binary_search(&end) { - Ok(x) => x + 1, - Err(x) => x, - }; + #[cfg(feature = "simd")] + let pos_start = vector::quad_search(&self.vec, start).unwrap_or_else(identity); + #[cfg(not(feature = "simd"))] + let pos_start = self.vec.binary_search(&start).unwrap_or_else(identity); + + #[cfg(feature = "simd")] + let pos_end_result = vector::quad_search(&self.vec[pos_start..], end); + #[cfg(not(feature = "simd"))] + let pos_end_result = self.vec[pos_start..].binary_search(&end); + + let pos_end = match pos_end_result { + Ok(x) => x + pos_start + 1, + Err(x) => x + pos_start, + }; + self.vec.drain(pos_start..pos_end); (pos_end - pos_start) as u64 } @@ -203,7 +230,10 @@ impl ArrayStore { } pub fn contains(&self, index: u16) -> bool { - self.vec.binary_search(&index).is_ok() + #[cfg(feature = "simd")] + return vector::quad_contains(&self.vec, index); + #[cfg(not(feature = "simd"))] + return self.vec.binary_search(&index).is_ok(); } pub fn contains_range(&self, range: RangeInclusive) -> bool { @@ -213,13 +243,20 @@ impl ArrayStore { if self.vec.len() < range_count { return false; } - let start_i = match self.vec.binary_search(&start) { + + #[cfg(feature = "simd")] + let result = vector::quad_search(&self.vec, start); + #[cfg(not(feature = "simd"))] + let result = self.vec.binary_search(&start); + + let start_i = match result { Ok(i) => i, Err(_) => return false, }; - // If there are `range_count` items, last item in the next range_count should be the - // expected end value, because this vec is sorted and has no duplicates + // If there are `range_count` items, last item in the next range_count + // should be the expected end value, because this vec is sorted and + // has no duplicates self.vec.get(start_i + range_count - 1) == Some(&end) } @@ -301,7 +338,12 @@ impl ArrayStore { } pub fn rank(&self, index: u16) -> u64 { - match self.vec.binary_search(&index) { + #[cfg(feature = "simd")] + let result = vector::quad_search(&self.vec, index); + #[cfg(not(feature = "simd"))] + let result = self.vec.binary_search(&index); + + match result { Ok(i) => i as u64 + 1, Err(i) => i as u64, } diff --git a/roaring/src/bitmap/store/array_store/vector.rs b/roaring/src/bitmap/store/array_store/vector.rs index 6fc14a9a..7262f29a 100644 --- a/roaring/src/bitmap/store/array_store/vector.rs +++ b/roaring/src/bitmap/store/array_store/vector.rs @@ -548,3 +548,118 @@ pub fn swizzle_to_front(val: u16x8, bitmask: u8) -> u16x8 { let swizzled: u8x16 = val_convert.swizzle_dyn(swizzle_idxs); u16x8::from_ne_bytes(swizzled) } + +#[inline] +pub fn quad_contains(slice: &[u16], val: u16) -> bool { + const GAP: usize = u16x8::LEN * 2; + + let (chunks, remaining) = slice.as_chunks::(); + + if chunks.is_empty() { + return match remaining.iter().copied().find(|v| *v >= val) { + Some(v) => v == val, + None => false, + }; + } + + let num_blocks = chunks.len(); + let mut base = 0; + let mut n = num_blocks; + while n > 3 { + let quarter = n >> 2; // equivalent to n / 4 + + let k1 = chunks[base + quarter][GAP - 1]; + let k2 = chunks[base + 2 * quarter][GAP - 1]; + let k3 = chunks[base + 3 * quarter][GAP - 1]; + + let c1 = (k1 < val) as usize; + let c2 = (k2 < val) as usize; + let c3 = (k3 < val) as usize; + + base += (c1 + c2 + c3) * quarter; + n -= 3 * quarter; + } + + while n > 1 { + let half = n >> 1; // equivalent to n / 2 + base = if chunks[base + half][GAP - 1] < val { base + half } else { base }; + n -= half; + } + + let lo = if chunks[base][GAP - 1] < val { base + 1 } else { base }; + + if lo < num_blocks { + let ndl = u16x8::splat(val); + // I would love to work with arrays here... + let v0 = u16x8::from_slice(&chunks[lo][..GAP / 2]); + let v1 = u16x8::from_slice(&chunks[lo][GAP / 2..]); + return (v0.simd_eq(ndl) | v1.simd_eq(ndl)).any(); + } + + match slice.iter().copied().skip(num_blocks * GAP).find(|v| *v >= val) { + Some(v) => v == val, + None => false, + } +} + +#[inline] +pub fn quad_search(slice: &[u16], val: u16) -> Result { + const GAP: usize = u16x8::LEN * 2; + + let (chunks, remaining) = slice.as_chunks::(); + + if chunks.is_empty() { + return match remaining.iter().copied().enumerate().find(|(_, v)| *v >= val) { + Some((i, v)) if v == val => Ok(i), + Some((i, _)) => Err(i), + None => Err(slice.len()), + }; + } + + let num_blocks = chunks.len(); + let mut base = 0; + let mut n = num_blocks; + while n > 3 { + let quarter = n >> 2; // equivalent to n / 4 + + let k1 = chunks[base + quarter][GAP - 1]; + let k2 = chunks[base + 2 * quarter][GAP - 1]; + let k3 = chunks[base + 3 * quarter][GAP - 1]; + + let c1 = (k1 < val) as usize; + let c2 = (k2 < val) as usize; + let c3 = (k3 < val) as usize; + + base += (c1 + c2 + c3) * quarter; + n -= 3 * quarter; + } + + while n > 1 { + let half = n >> 1; // equivalent to n / 2 + base = if chunks[base + half][GAP - 1] < val { base + half } else { base }; + n -= half; + } + + let lo = if chunks[base][GAP - 1] < val { base + 1 } else { base }; + + if lo < num_blocks { + let ndl = u16x8::splat(val); + // I would love to work with arrays here... + let v0 = u16x8::from_slice(&chunks[lo][..GAP / 2]); + let v1 = u16x8::from_slice(&chunks[lo][GAP / 2..]); + let base_index = lo * GAP; + return match (v0.simd_ge(ndl).first_set(), v1.simd_ge(ndl).first_set()) { + (Some(i), _) if v0[i] == val => Ok(base_index + i), + (Some(i), _) => Err(base_index + i), + (_, Some(i)) if v1[i] == val => Ok(base_index + GAP / 2 + i), + (_, Some(i)) => Err(base_index + GAP / 2 + i), + (None, None) => Err(slice.len()), + }; + } + + match slice.iter().copied().enumerate().skip(num_blocks * GAP).find(|(_, v)| *v >= val) { + Some((i, v)) if v == val => Ok(i), + Some((i, _)) => Err(i), + None => Err(slice.len()), + } +}