Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 36 additions & 21 deletions benches/cedarwood_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use cedarwood::Cedar;
use criterion::{criterion_group, criterion_main, Criterion};
use criterion::{black_box, criterion_group, criterion_main, Criterion};

fn build_cedar() -> Cedar {
let dict = vec![
Expand All @@ -23,30 +23,45 @@ fn build_cedar() -> Cedar {
cedar
}

fn bench_cedar_build() {
let _cedar = build_cedar();
}
fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("cedar build", |b| b.iter(|| black_box(build_cedar())));

fn bench_exact_match_search() {
let cedar = build_cedar();
let _ret = cedar.exact_match_search("中华人民");
}
c.bench_function("cedar exact_match_search", |b| {
let cedar = build_cedar();
b.iter(|| black_box(cedar.exact_match_search(black_box("中华人民"))))
});

fn bench_common_prefix_search() {
let cedar = build_cedar();
let _ret = cedar.common_prefix_search("中华人民");
}
c.bench_function("cedar common_prefix_search", |b| {
let cedar = build_cedar();
b.iter(|| black_box(cedar.common_prefix_search(black_box("中华人民"))))
});

fn bench_common_prefix_predict() {
let cedar = build_cedar();
let _ret = cedar.common_prefix_predict("中");
}
c.bench_function("cedar common_prefix_search (iter)", |b| {
let cedar = build_cedar();
b.iter(|| {
let mut count = 0i32;
for r in cedar.common_prefix_iter(black_box("中华人民")) {
count += r.0;
}
black_box(count)
})
});

fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("cedar build", |b| b.iter(bench_cedar_build));
c.bench_function("cedar exact_match_search", |b| b.iter(bench_exact_match_search));
c.bench_function("cedar common_prefix_search", |b| b.iter(bench_common_prefix_search));
c.bench_function("cedar common_prefix_predict", |b| b.iter(bench_common_prefix_predict));
c.bench_function("cedar common_prefix_predict", |b| {
let cedar = build_cedar();
b.iter(|| black_box(cedar.common_prefix_predict(black_box("中"))))
});

c.bench_function("cedar common_prefix_predict (iter)", |b| {
let cedar = build_cedar();
b.iter(|| {
let mut count = 0i32;
for r in cedar.common_prefix_predict_iter(black_box("中")) {
count += r.0;
}
black_box(count)
})
});
}

criterion_group!(benches, criterion_benchmark);
Expand Down
168 changes: 115 additions & 53 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ use smallvec::SmallVec;
use std::fmt;

/// NInfo stores the information about the trie
#[derive(Debug, Default, Clone)]
#[derive(Debug, Default, Clone, Copy)]
struct NInfo {
sibling: u8, // the index of right sibling, it is 0 if it doesn't have a sibling.
child: u8, // the index of the first child
}

/// Node contains the array of `base` and `check` as specified in the paper: "An efficient implementation of trie structures"
/// https://dl.acm.org/citation.cfm?id=146691
#[derive(Debug, Default, Clone)]
#[derive(Debug, Default, Clone, Copy)]
struct Node {
base_: i32, // if it is a negative value, then it stores the value of previous index that is free.
check: i32, // if it is a negative value, then it stores the value of next index that is free.
Expand Down Expand Up @@ -171,18 +171,38 @@ impl<'a> Iterator for PrefixIter<'a> {

fn next(&mut self) -> Option<Self::Item> {
while self.i < self.key.len() {
if let Some(value) = self.cedar.find(&self.key[self.i..=self.i], &mut self.from) {
if value == CEDAR_NO_VALUE {
self.i += 1;
continue;
} else {
let result = Some((value, self.i));
self.i += 1;
return result;
// Inline the single-byte traversal instead of calling find() with a 1-byte slice
let from = self.from;

#[cfg(feature = "reduced-trie")]
{
if self.cedar.array[from].base_ >= 0 {
break;
}
} else {
}

let base = self.cedar.array[from].base();
let to = (base ^ i32::from(self.key[self.i])) as usize;
if self.cedar.array[to].check != (from as i32) {
break;
}

self.from = to;
self.i += 1;

// Check for value at this position
#[cfg(feature = "reduced-trie")]
{
if self.cedar.array[to].base_ >= 0 {
return Some((self.cedar.array[to].base_, self.i - 1));
}
}

let terminal_base = self.cedar.array[to].base();
let terminal = &self.cedar.array[terminal_base as usize];
if terminal.check == (to as i32) && terminal.base_ != CEDAR_NO_VALUE {
return Some((terminal.base_, self.i - 1));
}
}

None
Expand Down Expand Up @@ -246,7 +266,7 @@ impl Cedar {
/// Initialize the Cedar for further use.
pub fn new() -> Self {
let mut array: Vec<Node> = Vec::with_capacity(256);
let n_infos: Vec<NInfo> = (0..256).map(|_| Default::default()).collect();
let n_infos: Vec<NInfo> = vec![NInfo::default(); 256];
let mut blocks: Vec<Block> = vec![Block::new(); 1];
let reject: Vec<i16> = (0..=256).map(|i| i + 1).collect();

Expand Down Expand Up @@ -284,6 +304,32 @@ impl Cedar {
}
}

/// SAFETY: `i` must be a valid index into `self.array`. Callers must ensure the trie's
/// structural invariants hold. Debug builds will panic on out-of-bounds access.
#[inline(always)]
unsafe fn node_unchecked(&self, i: usize) -> &Node {
debug_assert!(
i < self.array.len(),
"node_unchecked: index {} out of bounds (len {})",
i,
self.array.len()
);
self.array.get_unchecked(i)
}

/// SAFETY: `i` must be a valid index into `self.n_infos`. Callers must ensure the trie's
/// structural invariants hold. Debug builds will panic on out-of-bounds access.
#[inline(always)]
unsafe fn ninfo_unchecked(&self, i: usize) -> &NInfo {
debug_assert!(
i < self.n_infos.len(),
"ninfo_unchecked: index {} out of bounds (len {})",
i,
self.n_infos.len()
);
self.n_infos.get_unchecked(i)
}

/// Build the double array trie from the given key value pairs
pub fn build(&mut self, key_values: &[(&str, i32)]) {
for (key, value) in key_values {
Expand Down Expand Up @@ -368,20 +414,25 @@ impl Cedar {
}

// Find key from double array trie, with `from` as the cursor to traverse the nodes.
//
// SAFETY: The inner loop uses unchecked indexing for performance (this loop runs once per byte
// of the key and is the primary query bottleneck). Indices are valid by trie structural
// invariants: `from` is always a valid node, and `base ^ label` stays within the allocated
// array. Debug builds retain bounds checks via debug_assert in the helpers.
#[inline]
fn find(&self, key: &[u8], from: &mut usize) -> Option<i32> {
let mut pos = 0;

// recursively matching the key.
while pos < key.len() {
#[cfg(feature = "reduced-trie")]
{
if self.array[*from].base_ >= 0 {
if unsafe { self.node_unchecked(*from) }.base_ >= 0 {
break;
}
}

let to = (self.array[*from].base() ^ i32::from(key[pos])) as usize;
if self.array[to].check != (*from as i32) {
let to = (unsafe { self.node_unchecked(*from) }.base() ^ i32::from(key[pos])) as usize;
if unsafe { self.node_unchecked(to) }.check != (*from as i32) {
return None;
}

Expand Down Expand Up @@ -444,12 +495,12 @@ impl Cedar {
let mut e = self.array[from].base();

loop {
let n = self.array[from].clone();
let has_sibling = self.n_infos[(n.base() ^ i32::from(self.n_infos[from].child)) as usize].sibling != 0;
let base = self.array[from].base();
let has_sibling = self.n_infos[(base ^ i32::from(self.n_infos[from].child)) as usize].sibling != 0;

// if the node has siblings, then remove `e` from the sibling.
if has_sibling {
self.pop_sibling(from as i32, n.base(), (n.base() ^ e) as u8);
self.pop_sibling(from as i32, base, (base ^ e) as u8);
}

// maintain the data structures.
Expand Down Expand Up @@ -529,65 +580,74 @@ impl Cedar {
}

// To get the cursor of the first leaf node starting by `from`
//
// SAFETY: Uses unchecked indexing throughout. All indices are valid by trie structural
// invariants: `from` is always a valid node, and `base ^ child` stays within the array.
// This function is called per-result in common_prefix_predict and is performance-critical.
#[inline]
fn begin(&self, mut from: usize, mut p: usize) -> (Option<i32>, usize, usize) {
let base = self.array[from].base();
let mut c = self.n_infos[from].child;
let mut c = unsafe { self.ninfo_unchecked(from) }.child;

if from == 0 {
c = self.n_infos[(base ^ i32::from(c)) as usize].sibling;
let base = unsafe { self.node_unchecked(0) }.base();
c = unsafe { self.ninfo_unchecked((base ^ i32::from(c)) as usize) }.sibling;

// if no sibling couldn be found from the virtual root, then we are done.
if c == 0 {
return (None, from, p);
}
}

// recursively traversing down to look for the first leaf.
while c != 0 {
from = (self.array[from].base() ^ i32::from(c)) as usize;
c = self.n_infos[from].child;
from = (unsafe { self.node_unchecked(from) }.base() ^ i32::from(c)) as usize;
c = unsafe { self.ninfo_unchecked(from) }.child;
p += 1;
}

let node = unsafe { self.node_unchecked(from) };

#[cfg(feature = "reduced-trie")]
{
if self.array[from].base_ >= 0 {
return (Some(self.array[from].base_), from, p);
}
if node.base_ >= 0 {
return (Some(node.base_), from, p);
}

// To return the value of the leaf.
let v = self.array[(self.array[from].base() ^ i32::from(c)) as usize].base_;
let v = unsafe { self.node_unchecked(node.base() as usize) }.base_;
(Some(v), from, p)
}

// To move the cursor from one leaf to the next for the common_prefix_predict.
//
// SAFETY: Uses unchecked indexing throughout. All indices are valid by trie structural
// invariants: `from` is a valid node, and `check` always points to a valid parent.
// This function is called per-result in common_prefix_predict and is performance-critical.
#[inline]
fn next(&self, mut from: usize, mut p: usize, root: usize) -> (Option<i32>, usize, usize) {
#[cfg(feature = "reduced-trie")]
let mut c: u8 = if self.array[from].base_ < 0 {
self.n_infos[(self.array[from].base()) as usize].sibling
} else {
0
let mut c: u8 = {
let node = unsafe { self.node_unchecked(from) };
if node.base_ < 0 {
unsafe { self.ninfo_unchecked(node.base() as usize) }.sibling
} else {
0
}
};

#[cfg(not(feature = "reduced-trie"))]
let mut c: u8 = self.n_infos[(self.array[from].base()) as usize].sibling;
let mut c: u8 = {
let base = unsafe { self.node_unchecked(from) }.base();
unsafe { self.ninfo_unchecked(base as usize) }.sibling
};

// traversing up until there is a sibling or it has reached the root.
while c == 0 && from != root {
c = self.n_infos[from].sibling;
from = self.array[from].check as usize;
c = unsafe { self.ninfo_unchecked(from) }.sibling;
from = unsafe { self.node_unchecked(from) }.check as usize;

p -= 1;
}

if c != 0 {
// it has a sibling so we leverage on `begin` to traverse the subtree down again.
from = (self.array[from].base() ^ i32::from(c)) as usize;
let (v_, from_, p_) = self.begin(from, p + 1);
(v_, from_, p_)
from = (unsafe { self.node_unchecked(from) }.base() ^ i32::from(c)) as usize;
self.begin(from, p + 1)
} else {
// no more work since we couldn't find anything.
(None, from, p)
}
}
Expand All @@ -604,12 +664,13 @@ impl Cedar {
if last {
*head = 0;
} else {
let b = self.blocks[idx as usize].clone();
self.blocks[b.prev as usize].next = b.next;
self.blocks[b.next as usize].prev = b.prev;
let b_prev = self.blocks[idx as usize].prev;
let b_next = self.blocks[idx as usize].next;
self.blocks[b_prev as usize].next = b_next;
self.blocks[b_next as usize].prev = b_prev;

if idx == *head {
*head = b.next;
*head = b_next;
}
}
}
Expand Down Expand Up @@ -697,7 +758,8 @@ impl Cedar {
};

let idx = e >> 8;
let n = self.array[e as usize].clone();
let n_base = self.array[e as usize].base_;
let n_check = self.array[e as usize].check;

self.blocks[idx as usize].num -= 1;
// move the block at idx to the correct linked-list depending the free slots it still have.
Expand All @@ -706,11 +768,11 @@ impl Cedar {
self.transfer_block(idx, BlockType::Closed, BlockType::Full, self.blocks_head_full == 0);
}
} else {
self.array[(-n.base_) as usize].check = n.check;
self.array[(-n.check) as usize].base_ = n.base_;
self.array[(-n_base) as usize].check = n_check;
self.array[(-n_check) as usize].base_ = n_base;

if e == self.blocks[idx as usize].e_head {
self.blocks[idx as usize].e_head = -n.check;
self.blocks[idx as usize].e_head = -n_check;
}

if idx != 0 && self.blocks[idx as usize].num == 1 && self.blocks[idx as usize].trial != self.max_trial {
Expand Down
Loading