diff --git a/src/experimental/zipper_algebra.rs b/src/experimental/zipper_algebra.rs index 994ec6d..679e3e9 100644 --- a/src/experimental/zipper_algebra.rs +++ b/src/experimental/zipper_algebra.rs @@ -1,13 +1,17 @@ +use pathmap_derive::PolyZipperExplicit; + use crate::{ alloc::{Allocator, GlobalAlloc}, ring::{AlgebraicResult, COUNTER_IDENT, DistributiveLattice, Lattice, SELF_IDENT}, utils::ByteMask, zipper::{ - ReadZipperUntracked, Zipper, ZipperInfallibleSubtries, ZipperMoving, ZipperValues, - ZipperWriting, + ReadZipperTracked, ReadZipperUntracked, Zipper, ZipperInfallibleSubtries, ZipperMoving, + ZipperValues, ZipperWriting, }, }; +pub use zipper_algebra_poly::ZipperMergeF; + /// Extension trait providing algebraic merge operations on radix-256 trie zippers. /// /// This trait exposes high-level operations such as [`join`](Self::join), @@ -94,6 +98,11 @@ impl ZipperAlgebraExt { } +impl ZipperAlgebraExt + for ReadZipperTracked<'_, '_, V, A> +{ +} + /// Performs an ordered join (least upper bound) of two radix-256 tries using zipper traversal. /// /// This function merges two tries by simultaneously traversing them in lexicographic order, @@ -343,9 +352,24 @@ trait MergePolicy { trait ValuePolicy { fn combine(l: Option<&V>, r: Option<&V>) -> Option; + #[inline] + fn combine_acc(l: Option, r: Option<&V>) -> Option { + Self::combine(l.as_ref(), r) + } fn combine3(l: Option<&V>, m: Option<&V>, r: Option<&V>) -> Option { // (l op m) op r - Self::combine(Self::combine(l, m).as_ref(), r) + Self::combine_acc(Self::combine(l, m), r) + } + fn combine4(a: Option<&V>, b: Option<&V>, c: Option<&V>, d: Option<&V>) -> Option { + // ((a op b) op c) op d + Self::combine_acc(Self::combine_acc(Self::combine(a, b), c), d) + } + fn combine_n<'a, I>(vals: I) -> Option + where + I: Iterator>, + V: 'a, + { + vals.fold(None, |acc, v| Self::combine_acc(acc, v)) } } @@ -366,8 +390,8 @@ where let mut k = 0; let mut lhs_mask = lhs.child_mask(); let mut rhs_mask = rhs.child_mask(); - let mut lhs_idx = 0; - let mut rhs_idx = 0; + let mut lhs_next = lhs_mask.indexed_bit::(0); + let mut rhs_next = rhs_mask.indexed_bit::(0); // At each node, the algorithm treats the sets of child edges of `lhs` and `rhs` as two sorted // sequences and performs a merge-like traversal: @@ -380,18 +404,15 @@ where // and an explicit depth counter (`k`). 'ascend: loop { 'merge_level: loop { - let lhs_next = lhs_mask.indexed_bit::(lhs_idx as usize); - let rhs_next = rhs_mask.indexed_bit::(rhs_idx as usize); - match lhs_next { Some(lhs_byte) => match rhs_next { Some(rhs_byte) if lhs_byte < rhs_byte => { P::on_left_only(lhs, ByteMask::from_range(lhs_byte..rhs_byte), out); - lhs_idx = lhs_mask.index_of(rhs_byte); + lhs_next = (lhs_mask & ByteMask::from_range(rhs_byte..)).next_bit(0); } Some(rhs_byte) if lhs_byte > rhs_byte => { P::on_right_only(rhs, ByteMask::from_range(rhs_byte..lhs_byte), out); - rhs_idx = rhs_mask.index_of(lhs_byte); + rhs_next = (rhs_mask & ByteMask::from_range(lhs_byte..)).next_bit(0); } Some(rhs_byte) => { // equal → descend @@ -407,8 +428,8 @@ where lhs_mask = lhs.child_mask(); rhs_mask = rhs.child_mask(); - lhs_idx = 0; - rhs_idx = 0; + lhs_next = lhs_mask.indexed_bit::(0); + rhs_next = rhs_mask.indexed_bit::(0); k += 1; continue 'merge_level; @@ -437,11 +458,11 @@ where rhs.ascend_byte(); rhs_mask = rhs.child_mask(); - rhs_idx = rhs_mask.index_of(byte_from) + 1; + rhs_next = rhs_mask.next_bit(byte_from); lhs.ascend_byte(); lhs_mask = lhs.child_mask(); - lhs_idx = lhs_mask.index_of(byte_from) + 1; + lhs_next = lhs_mask.next_bit(byte_from); out.ascend_byte(); k -= 1; @@ -508,16 +529,12 @@ where let mut lhs_mask = lhs.child_mask(); let mut mid_mask = mid.child_mask(); let mut rhs_mask = rhs.child_mask(); - let mut lhs_idx = 0; - let mut mid_idx = 0; - let mut rhs_idx = 0; + let mut l = lhs_mask.indexed_bit::(0); + let mut m = mid_mask.indexed_bit::(0); + let mut r = rhs_mask.indexed_bit::(0); 'ascend: loop { 'merge_level: loop { - let l = lhs_mask.indexed_bit::(lhs_idx as usize); - let m = mid_mask.indexed_bit::(mid_idx as usize); - let r = rhs_mask.indexed_bit::(rhs_idx as usize); - let mut a = l; let mut b = m; let mut c = r; @@ -543,7 +560,7 @@ where cmp_swap(&mut b, &mut c); if let Some(next) = b { P::on_single(lhs, L as u64, ByteMask::from_range(min..next), out); - lhs_idx = lhs_mask.index_of(next) + l = (lhs_mask & ByteMask::from_range(next..)).next_bit(0); } else { P::on_single(lhs, L as u64, ByteMask::from_range(min..), out); break 'merge_level; @@ -553,7 +570,7 @@ where cmp_swap(&mut b, &mut c); if let Some(next) = b { P::on_single(mid, M as u64, ByteMask::from_range(min..next), out); - mid_idx = mid_mask.index_of(next); + m = (mid_mask & ByteMask::from_range(next..)).next_bit(0); } else { P::on_single(mid, M as u64, ByteMask::from_range(min..), out); break 'merge_level; @@ -563,7 +580,7 @@ where cmp_swap(&mut b, &mut c); if let Some(next) = b { P::on_single(rhs, R as u64, ByteMask::from_range(min..next), out); - rhs_idx = rhs_mask.index_of(next); + r = (rhs_mask & ByteMask::from_range(next..)).next_bit(0); } else { P::on_single(rhs, R as u64, ByteMask::from_range(min..), out); break 'merge_level; @@ -574,22 +591,22 @@ where if P::descend_on_some_equal(LM as u64) { descend2::(min, lhs, mid, out); } - lhs_idx += 1; - mid_idx += 1; + l = lhs_mask.next_bit(min); + m = mid_mask.next_bit(min); } MR => { if P::descend_on_some_equal(MR as u64) { descend2::(min, mid, rhs, out); } - mid_idx += 1; - rhs_idx += 1; + m = mid_mask.next_bit(min); + r = rhs_mask.next_bit(min); } LR => { if P::descend_on_some_equal(LR as u64) { descend2::(min, lhs, rhs, out); } - lhs_idx += 1; - rhs_idx += 1; + l = lhs_mask.next_bit(min); + r = rhs_mask.next_bit(min); } // full 3-way LMR => { @@ -607,9 +624,9 @@ where mid_mask = mid.child_mask(); rhs_mask = rhs.child_mask(); - lhs_idx = 0; - mid_idx = 0; - rhs_idx = 0; + l = lhs_mask.indexed_bit::(0); + m = mid_mask.indexed_bit::(0); + r = rhs_mask.indexed_bit::(0); k += 1; continue 'merge_level; @@ -630,15 +647,565 @@ where rhs.ascend_byte(); rhs_mask = rhs.child_mask(); - rhs_idx = rhs_mask.index_of(byte_from) + 1; + r = rhs_mask.next_bit(byte_from); mid.ascend_byte(); mid_mask = mid.child_mask(); - mid_idx = mid_mask.index_of(byte_from) + 1; + m = mid_mask.next_bit(byte_from); lhs.ascend_byte(); lhs_mask = lhs.child_mask(); - lhs_idx = lhs_mask.index_of(byte_from) + 1; + l = lhs_mask.next_bit(byte_from); + + out.ascend_byte(); + k -= 1; + } +} + +// semi-unrolled (bitmask-driven) +// Beyond 4, the combinatorics start to creak, but k = 4 is a sweet spot: +// - still manageable (16 frontier cases) +// - still branch-predictable +// - still worth it for hot paths +fn zipper_merge4( + z0: &mut Z0, + z1: &mut Z1, + z2: &mut Z2, + z3: &mut Z3, + out: &mut Out, +) where + V: Clone + Send + Sync, + P: MergePolicy + ValuePolicy, + A: Allocator, + Z0: ZipperInfallibleSubtries + ZipperMoving, + Z1: ZipperInfallibleSubtries + ZipperMoving, + Z2: ZipperInfallibleSubtries + ZipperMoving, + Z3: ZipperInfallibleSubtries + ZipperMoving, + Out: ZipperWriting, +{ + // merge root values before descending + if let Some(v) = P::combine4(z0.val(), z1.val(), z2.val(), z3.val()) { + out.set_val(v); + } + + let mut k = 0; + // state (fully unrolled) + let mut m0 = z0.child_mask(); + let mut m1 = z1.child_mask(); + let mut m2 = z2.child_mask(); + let mut m3 = z3.child_mask(); + + let mut b0 = m0.indexed_bit::(0); + let mut b1 = m1.indexed_bit::(0); + let mut b2 = m2.indexed_bit::(0); + let mut b3 = m3.indexed_bit::(0); + + 'ascend: loop { + 'merge_level: loop { + // min selection + let mut a = b0; + let mut b = b1; + let mut c = b2; + let mut d = b3; + + cmp_swap(&mut a, &mut b); + cmp_swap(&mut a, &mut c); + cmp_swap(&mut a, &mut d); + + if let Some(min) = a { + let mut frontier = 0u8; + if b0 == a { + frontier |= 0b0001; + } + if b1 == a { + frontier |= 0b0010; + } + if b2 == a { + frontier |= 0b0100; + } + if b3 == a { + frontier |= 0b1000; + } + + // full match + if frontier == 0b1111 { + out.descend_to_byte(min); + + z0.descend_to_byte(min); + z1.descend_to_byte(min); + z2.descend_to_byte(min); + z3.descend_to_byte(min); + + if let Some(v) = P::combine4(z0.val(), z1.val(), z2.val(), z3.val()) { + out.set_val(v); + } + + m0 = z0.child_mask(); + b0 = m0.indexed_bit::(0); + m1 = z1.child_mask(); + b1 = m1.indexed_bit::(0); + m2 = z2.child_mask(); + b2 = m2.indexed_bit::(0); + m3 = z3.child_mask(); + b3 = m3.indexed_bit::(0); + + k += 1; + continue 'merge_level; + } + + let cnt = frontier.count_ones(); + // singleton + if cnt == 1 { + cmp_swap(&mut b, &mut c); + cmp_swap(&mut b, &mut d); + + match frontier { + 0b0001 => { + if let Some(next) = b { + P::on_single(z0, 0b0001, ByteMask::from_range(min..next), out); + b0 = (m0 & ByteMask::from_range(next..)).next_bit(0); + } else { + P::on_single(z0, 0b0001, ByteMask::from_range(min..), out); + break 'merge_level; + } + } + 0b0010 => { + if let Some(next) = b { + P::on_single(z1, 0b0010, ByteMask::from_range(min..next), out); + b1 = (m1 & ByteMask::from_range(next..)).next_bit(0); + } else { + P::on_single(z1, 0b0010, ByteMask::from_range(min..), out); + break 'merge_level; + } + } + 0b0100 => { + if let Some(next) = b { + P::on_single(z2, 0b0100, ByteMask::from_range(min..next), out); + b2 = (m2 & ByteMask::from_range(next..)).next_bit(0); + } else { + P::on_single(z2, 0b0100, ByteMask::from_range(min..), out); + break 'merge_level; + } + } + 0b1000 => { + if let Some(next) = b { + P::on_single(z3, 0b1000, ByteMask::from_range(min..next), out); + b3 = (m3 & ByteMask::from_range(next..)).next_bit(0); + } else { + P::on_single(z3, 0b1000, ByteMask::from_range(min..), out); + break 'merge_level; + } + } + _ => unreachable!(), + }; + } else { + // partial overlap (2 or 3) + + // avoid 16 match arms and duplicated logic + if P::descend_on_some_equal(frontier as u64) { + out.descend_to_byte(min); + + if frontier & 0b0001 != 0 { + z0.descend_to_byte(min); + } + if frontier & 0b0010 != 0 { + z1.descend_to_byte(min); + } + if frontier & 0b0100 != 0 { + z2.descend_to_byte(min); + } + if frontier & 0b1000 != 0 { + z3.descend_to_byte(min); + } + + // recurse on subset (still using 4-way function, but inactive ones won't match) + if (cnt == 2) { + let i = frontier.trailing_zeros(); + let j = (frontier & !(1 << i)).trailing_zeros(); + match (i, j) { + (0, 1) => { + zipper_merge::(z0, z1, out); + } + (0, 2) => { + zipper_merge::(z0, z2, out); + } + (0, 3) => { + zipper_merge::(z0, z3, out); + } + (1, 2) => { + zipper_merge::(z1, z2, out); + } + (1, 3) => { + zipper_merge::(z1, z3, out); + } + (2, 3) => { + zipper_merge::(z2, z3, out); + } + _ => unreachable!(), + } + } else { + // cnt == 3 + let mut bits = frontier; + let i = bits.trailing_zeros(); + bits &= bits - 1; // trick: it removes the lowest bit set from a bitmask + let j = bits.trailing_zeros(); + bits &= bits - 1; + let k = bits.trailing_zeros(); + match (i, j, k) { + (0, 1, 2) => { + zipper_merge3::(z0, z1, z2, out); + } + (0, 1, 3) => { + zipper_merge3::(z0, z1, z3, out); + } + (0, 2, 3) => { + zipper_merge3::(z0, z2, z3, out); + } + (1, 2, 3) => { + zipper_merge3::(z1, z2, z3, out); + } + _ => unreachable!(), + } + } + + if frontier & 0b0001 != 0 { + z0.ascend_byte(); + } + if frontier & 0b0010 != 0 { + z1.ascend_byte(); + } + if frontier & 0b0100 != 0 { + z2.ascend_byte(); + } + if frontier & 0b1000 != 0 { + z3.ascend_byte(); + } + + out.ascend_byte(); + } + // then advance + if frontier & 0b0001 != 0 { + b0 = m0.next_bit(min); + } + if frontier & 0b0010 != 0 { + b1 = m1.next_bit(min); + } + if frontier & 0b0100 != 0 { + b2 = m2.next_bit(min); + } + if frontier & 0b1000 != 0 { + b3 = m3.next_bit(min); + } + } + } else { + break 'merge_level; + } + } + + // If we are at root and no deeper recursion pending, we're done + if k == 0 { + break 'ascend; + } + + let byte_from = *z0.path().last().expect("non-empty path when k > 0"); + + z0.ascend_byte(); + m0 = z0.child_mask(); + b0 = m0.next_bit(byte_from); + + z1.ascend_byte(); + m1 = z1.child_mask(); + b1 = m1.next_bit(byte_from); + + z2.ascend_byte(); + m2 = z2.child_mask(); + b2 = m2.next_bit(byte_from); + + z3.ascend_byte(); + m3 = z3.child_mask(); + b3 = m3.next_bit(byte_from); + + out.ascend_byte(); + k -= 1; + } +} + +use zipper_algebra_poly::SomeMutRefZ as Z; + +// - The function is fully monomorphized over `N` and uses a bitmask (`active`) +// to track participating zippers. +// - Small frontiers (`k ≤ 4`) are dispatched to specialized implementations +// for improved performance. +// - Requires `N ≤ 64`. +fn zipper_merge_n_mono( + zs: &mut [Z<'_, '_, '_, V, A>; N], + active: u64, + out: &mut Out, +) where + V: Clone + Send + Sync + Unpin, + P: MergePolicy + ValuePolicy, + A: Allocator, + Out: ZipperWriting, +{ + debug_assert!(N > 0 && N <= 64); + // LLVM needs to prove: `0 ≤ i < N` But i comes from: `i = bits.trailing_zeros() as usize;`` So the + // compiler must connect: “this bitmask only contains bits < N” + assert!(active >> N == 0); + + #[inline] + fn active_bits(active: u64) -> impl Iterator { + (0..N).filter(move |i| (active >> i) & 1 != 0) + } + + fn zippers<'a, 'trie, 'path, V, A, const N: usize>( + zs: &'a [Z<'a, 'trie, 'path, V, A>; N], + active: u64, + ) -> impl Iterator)> + where + V: Clone + Send + Sync + Unpin, + A: Allocator, + { + active_bits::(active).map(|i| (i, &zs[i])) + } + + fn values<'a, V, A, const N: usize>( + zs: &'a [Z<'a, '_, '_, V, A>; N], + active: u64, + ) -> impl Iterator> + where + V: Clone + Send + Sync + Unpin, + A: Allocator, + { + zippers(zs, active).map(|(_, z)| z.val()) + } + + // small micro-helpers + #[inline(always)] + fn for_each_bit(mut bits: u64, mut f: impl FnMut(usize)) { + while bits != 0 { + let i = bits.trailing_zeros() as usize; + bits &= bits - 1; + f(i); + } + } + + #[inline(always)] + fn with_k( + xs: &mut [T], + mut bits: u64, + f: impl FnOnce([&mut T; K]) -> R, + ) -> R { + debug_assert!(bits.count_ones() as usize >= K); + + // collect raw pointers first (safe) + let mut ptrs: [*mut T; K] = [std::ptr::null_mut(); K]; + + let mut i = 0; + while i < K { + let idx = bits.trailing_zeros() as usize; + bits &= bits - 1; + ptrs[i] = unsafe { xs.as_mut_ptr().add(idx) }; + i += 1; + } + + // SAFETY: + // - indices are distinct (bitmask) + // - derived from same slice + + // should be zero-cost after inlining + let refs = unsafe { ptrs.map(|p| &mut *p) }; + + f(refs) + } + + // combine root values + if let Some(v) = P::combine_n(values(zs, active)) { + out.set_val(v); + } + + let mut bytes = [None; N]; + let mut masks = [ByteMask::EMPTY; N]; + for (i, z) in zippers(zs, active) { + masks[i] = z.child_mask(); + bytes[i] = masks[i].indexed_bit::(0); + } + + // At each node, the algorithm: + // + // - Treats the child edges of all active zippers as sorted byte sequences, + // - Computes the minimal byte `a` across all inputs, + // - Forms the *frontier* — the subset of zippers containing `a`, + // - Dispatches based on frontier size: + // + // - **Full match (`frontier == active`)** + // Descends into all zippers without recursion (fast path). + // + // - **Singleton (`|frontier| = 1`)** + // Grafts the corresponding subtrie directly into the output. + // + // - **Partial overlap (`1 < |frontier| < N`)** + // Optionally descends into the subset, dispatching to specialized + // implementations for small arities (`k ≤ 4`) or recursively invoking + // this function on the subset. + // + // The traversal is performed iteratively using zipper movements + // (`descend_to_byte` / `ascend_byte`) and an explicit depth counter, + // avoiding recursion in the common case. + let mut k = 0; + debug_assert!(active.count_ones() > 0); + 'ascend: loop { + 'merge_level: loop { + let mut min = None; + let mut frontier = 0u64; + let mut next = None; + + for i in active_bits::(active) { + if let Some(b) = bytes[i] { + match min { + None => { + min = Some(b); + frontier = 1 << i; + } + Some(m) if b < m => { + next = Some(m); + min = Some(b); + frontier = 1 << i; + } + Some(m) if b == m => { + frontier |= 1 << i; + } + Some(m) => { + next = match next { + old @ Some(n) if n <= b => old, + _ => Some(b), + }; + } + } + } + } + + debug_assert!(frontier <= active); + + match min { + None => { + break 'merge_level; + } + Some(a) => { + // Dispatch + + // - Case A: full match (frontier == all bits) + if frontier == active { + out.descend_to_byte(a); + + // descend and refresh masks and indices + for_each_bit(active, |i| { + let mut z = &mut zs[i]; + z.descend_to_byte(a); + masks[i] = z.child_mask(); + bytes[i] = masks[i].indexed_bit::(0); + }); + + if let Some(v) = P::combine_n(values(zs, active)) { + out.set_val(v); + } + + k += 1; + continue 'merge_level; + } + + let cnt = frontier.count_ones(); + // - Case B: singleton (|frontier| = 1) + if (cnt == 1) { + let i = frontier.trailing_zeros() as usize; + match next { + None => { + P::on_single(&mut zs[i], frontier, ByteMask::from_range(a..), out); + break 'merge_level; + } + Some(b) => { + P::on_single(&mut zs[i], frontier, ByteMask::from_range(a..b), out); + // advance + bytes[i] = (masks[i] & ByteMask::from_range(b..)).next_bit(0); + } + } + } else { + // - Case C: subset (1 < k < N) + if P::descend_on_some_equal(frontier) { + out.descend_to_byte(a); + match cnt { + 2 => with_k::<2, _, _>(zs, frontier, |[lhs, rhs]| { + lhs.descend_to_byte(a); + rhs.descend_to_byte(a); + + zipper_merge::(lhs, rhs, out); + + rhs.ascend_byte(); + lhs.ascend_byte(); + }), + 3 => with_k::<3, _, _>(zs, frontier, |[lhs, mid, rhs]| { + lhs.descend_to_byte(a); + mid.descend_to_byte(a); + rhs.descend_to_byte(a); + + zipper_merge3::(lhs, mid, rhs, out); + + rhs.ascend_byte(); + mid.ascend_byte(); + lhs.ascend_byte(); + }), + 4 => with_k::<4, _, _>(zs, frontier, |[z0, z1, z2, z3]| { + z0.descend_to_byte(a); + z1.descend_to_byte(a); + z2.descend_to_byte(a); + z3.descend_to_byte(a); + + zipper_merge4::(z0, z1, z2, z3, out); + + z3.ascend_byte(); + z2.ascend_byte(); + z1.ascend_byte(); + z0.ascend_byte(); + }), + _ => { + // descend all active in the frontier + for_each_bit(frontier, |i| zs[i].descend_to_byte(a)); + + // recursive call with SAME array, smaller mask + zipper_merge_n_mono::(zs, frontier, out); + + //ascend + for_each_bit(frontier, |i| { + zs[i].ascend_byte(); + }); + } + } + + out.ascend_byte(); + } + + // advance indices + for_each_bit(frontier, |i| { + bytes[i] = masks[i].next_bit(a); + }); + } + } + } + } + + if (k == 0) { + break 'ascend; + } + + let i0 = active.trailing_zeros() as usize; + let byte_from = *zs[i0].path().last().expect("non-empty path when k > 0"); + + // ascend + for_each_bit(active, |i| { + let mut z = &mut zs[i]; + z.ascend_byte(); + masks[i] = z.child_mask(); + bytes[i] = masks[i].next_bit(byte_from); + }); out.ascend_byte(); k -= 1; @@ -687,6 +1254,28 @@ impl ValuePolicy for Join { r.cloned() } } + + fn combine_acc(l: Option, r: Option<&V>) -> Option { + if let Some(lv) = l { + if let Some(rv) = r { + match lv.pjoin(rv) { + AlgebraicResult::None => None, + AlgebraicResult::Identity(mask) => { + if mask & SELF_IDENT != 0 { + Some(lv) + } else { + r.cloned() + } + } + AlgebraicResult::Element(v) => Some(v), + } + } else { + Some(lv) + } + } else { + r.cloned() + } + } } // ==================== MEET ==================== @@ -716,6 +1305,27 @@ impl ValuePolicy for Meet { fn combine3(l: Option<&V>, m: Option<&V>, r: Option<&V>) -> Option { l.and_then(|x| m.and_then(|y| r.and_then(|z| meet_acc(meet_refs(x, y)?, z)))) } + + fn combine4(a: Option<&V>, b: Option<&V>, c: Option<&V>, d: Option<&V>) -> Option { + a.and_then(|w| { + b.and_then(|x| { + c.and_then(|y| d.and_then(|z| meet_acc(meet_acc(meet_refs(w, x)?, y)?, z))) + }) + }) + } + + fn combine_n<'a, I>(vals: I) -> Option + where + I: Iterator>, + V: 'a, + { + let mut it = vals; + let z = it.next()?.cloned()?; + it.try_fold(z, |acc, v| { + let rv = v?; + meet_acc(acc, rv) + }) + } } #[inline] @@ -792,23 +1402,385 @@ impl ValuePolicy for Subtract { fn combine(l: Option<&V>, r: Option<&V>) -> Option { l.and_then(|lv| { if let Some(rv) = r { - match lv.psubtract(rv) { - AlgebraicResult::None => None, - AlgebraicResult::Identity(mask) => { - if mask & SELF_IDENT != 0 { - Some(lv.clone()) - } else { - None - } - } - AlgebraicResult::Element(v) => Some(v), - } + subtract_refs(lv, rv) } else { // lhs-only → keep Some(lv.clone()) } }) } + + fn combine_n<'a, I>(vals: I) -> Option + where + I: Iterator>, + V: 'a, + { + let mut it = vals; + let z = it.next()?.cloned()?; + it.try_fold(z, |acc, v| match v { + Some(rv) => subtract_acc(acc, rv), + None => Some(acc), + }) + } + + fn combine_acc(l: Option, r: Option<&V>) -> Option { + l.and_then(|lv| { + if let Some(rv) = r { + subtract_acc(lv, rv) + } else { + // lhs-only → keep + Some(lv) + } + }) + } +} + +#[inline] +fn subtract_refs(a: &V, b: &V) -> Option { + match a.psubtract(b) { + AlgebraicResult::None => None, + AlgebraicResult::Identity(mask) => { + if mask & SELF_IDENT != 0 { + Some(a.clone()) + } else { + None + } + } + AlgebraicResult::Element(v) => Some(v), + } +} +#[inline] +fn subtract_acc(a: V, b: &V) -> Option { + match a.psubtract(b) { + AlgebraicResult::None => None, + AlgebraicResult::Identity(mask) => { + if mask & SELF_IDENT != 0 { + Some(a) + } else { + None + } + } + AlgebraicResult::Element(v) => Some(v), + } +} + +mod zipper_algebra_poly { + // ==================== Machinery for zipper_merge_n ==================== + use crate as pathmap; + use crate::PathMap; + use crate::alloc::Allocator; + use crate::ring::{DistributiveLattice, Lattice}; + use crate::trie_node::*; + use crate::zipper::*; + + #[derive(PolyZipperExplicit)] + #[poly_zipper_explicit(traits(ZipperMoving, ZipperValues))] + pub(super) enum SomeMutRefZ<'a, 'trie, 'path, V: Clone + Send + Sync + Unpin, A: Allocator> { + RZ(&'a mut ReadZipperUntracked<'trie, 'path, V, A>), + RZT(&'a mut ReadZipperTracked<'trie, 'path, V, A>), + } + + impl ZipperInfallibleSubtries + for SomeMutRefZ<'_, '_, '_, V, A> + { + fn make_map(&self) -> PathMap { + match self { + SomeMutRefZ::RZ(inner) => inner.make_map(), + SomeMutRefZ::RZT(inner) => inner.make_map(), + } + } + + fn get_trie_ref(&self) -> TrieRef<'_, V, A> { + match self { + SomeMutRefZ::RZ(inner) => inner.get_trie_ref(), + SomeMutRefZ::RZT(inner) => inner.get_trie_ref(), + } + } + + fn get_focus(&self) -> OpaqueAbstractNodeRef<'_, V, A> { + match self { + SomeMutRefZ::RZ(inner) => inner.get_focus(), + SomeMutRefZ::RZT(inner) => inner.get_focus(), + } + } + + fn try_borrow_focus(&self) -> Option> { + match self { + SomeMutRefZ::RZ(inner) => inner.try_borrow_focus(), + SomeMutRefZ::RZT(inner) => inner.try_borrow_focus(), + } + } + } + + pub trait ZipperMergeF + where + V: Clone + Send + Sync, + A: Allocator, + Self: Sized, + { + /// Performs an N-way ordered join (least upper bound) of radix-256 trie zippers using a stackless traversal. + /// + /// This function generalizes pairwise [`super::zipper_join`] to an arbitrary number of input tries, + fn join_n(self, out: &mut Out) + where + V: Lattice, + { + self.merge_n::(out); + } + + /// Performs an N-way ordered meet(greeatest lower bound) of radix-256 trie zippers using a stackless traversal. + /// + /// This function generalizes pairwise [`super::zipper_meet`] to an arbitrary number of input tries, + fn meet_n(self, out: &mut Out) + where + V: Lattice, + { + self.merge_n::(out); + } + + /// Performs an N-way ordered subtraction (left-associative) of radix-256 trie zippers using a stackless traversal. + /// + /// This function generalizes pairwise [`super::zipper_subtract`] to an arbitrary number of input tries, + fn subtract_n(self, out: &mut Out) + where + V: DistributiveLattice, + { + self.merge_n::(out); + } + + fn merge_n

(self, out: &mut Out) + where + P: super::MergePolicy + super::ValuePolicy; + } + + impl ZipperMergeF for (&mut Z1, &mut Z2) + where + V: Clone + Send + Sync, + A: Allocator, + Z1: ZipperInfallibleSubtries + ZipperMoving, + Z2: ZipperInfallibleSubtries + ZipperMoving, + Out: ZipperWriting, + { + fn merge_n

(mut self, out: &mut Out) + where + P: super::MergePolicy + super::ValuePolicy, + { + super::zipper_merge::(self.0, self.1, out); + } + } + + impl ZipperMergeF for (&mut Z1, &mut Z2, &mut Z3) + where + V: Clone + Send + Sync, + A: Allocator, + Z1: ZipperInfallibleSubtries + ZipperMoving, + Z2: ZipperInfallibleSubtries + ZipperMoving, + Z3: ZipperInfallibleSubtries + ZipperMoving, + Out: ZipperWriting, + { + fn merge_n

(mut self, out: &mut Out) + where + P: super::MergePolicy + super::ValuePolicy, + { + super::zipper_merge3::(self.0, self.1, self.2, out); + } + } + + impl ZipperMergeF for (&mut Z1, &mut Z2, &mut Z3, &mut Z4) + where + V: Clone + Send + Sync, + A: Allocator, + Z1: ZipperInfallibleSubtries + ZipperMoving, + Z2: ZipperInfallibleSubtries + ZipperMoving, + Z3: ZipperInfallibleSubtries + ZipperMoving, + Z4: ZipperInfallibleSubtries + ZipperMoving, + Out: ZipperWriting, + { + fn merge_n

(mut self, out: &mut Out) + where + P: super::MergePolicy + super::ValuePolicy, + { + super::zipper_merge4::(self.0, self.1, self.2, self.3, out); + } + } + + macro_rules! impl_zipper_merge_f { + ($($Z:ident),+) => { + impl<'trie, 'path, V, $($Z),+, Out, A> ZipperMergeF + for ($( &mut $Z ),+) + where + V: Clone + Send + Sync + Unpin + 'trie, + A: Allocator + 'trie, + $( + for<'x> &'x mut $Z: Into>, + )+ + Out: ZipperWriting, + { + fn merge_n

(mut self, out: &mut Out) + where + P: super::MergePolicy + super::ValuePolicy, + { + // destructure the tuple + let ($( $Z ),+) = self; + + let mut zs = [ + $( $Z.into() ),+ + ]; + + let active: u64 = (1 << zs.len()) - 1; + + super::zipper_merge_n_mono::( + &mut zs, + active, + out, + ); + } + } + }; +} + + impl_zipper_merge_f!(Z1, Z2, Z3, Z4, Z5); + impl_zipper_merge_f!(Z1, Z2, Z3, Z4, Z5, Z6); + impl_zipper_merge_f!(Z1, Z2, Z3, Z4, Z5, Z6, Z7); + impl_zipper_merge_f!(Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8); + impl_zipper_merge_f!(Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9); + impl_zipper_merge_f!(Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10); + impl_zipper_merge_f!(Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11); + impl_zipper_merge_f!(Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12); + impl_zipper_merge_f!(Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13); + impl_zipper_merge_f!(Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21, Z22 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21, Z22, Z23 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21, Z22, Z23, Z24 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21, Z22, Z23, Z24, Z25 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21, Z22, Z23, Z24, Z25, Z26 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21, Z22, Z23, Z24, Z25, Z26, Z27 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21, Z22, Z23, Z24, Z25, Z26, Z27, Z28 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21, Z22, Z23, Z24, Z25, Z26, Z27, Z28, Z29 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21, Z22, Z23, Z24, Z25, Z26, Z27, Z28, Z29, Z30 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21, Z22, Z23, Z24, Z25, Z26, Z27, Z28, Z29, Z30, Z31 + ); + impl_zipper_merge_f!( + Z1, Z2, Z3, Z4, Z5, Z6, Z7, Z8, Z9, Z10, Z11, Z12, Z13, Z14, Z15, Z16, Z17, Z18, Z19, Z20, + Z21, Z22, Z23, Z24, Z25, Z26, Z27, Z28, Z29, Z30, Z31, Z32 + ); + + /// Performs an N-ary zipper join by borrowing all inputs mutably + /// and forwarding them to [`ZipperMergeF::join_n`]. + /// + /// # Example + /// ``` + /// zipper_join_n!(z1, z2, z3 => out); + /// ``` + /// + /// Expands roughly to: + /// ```ignore + /// (&mut z1, &mut z2, &mut z3).join_n(&mut out) + /// ``` + /// + /// # See also + /// [`ZipperMergeF::join_n`] + #[macro_export] + macro_rules! zipper_join_n { + ( $($z:ident),+ => $out:ident ) => {{ + ( $( &mut $z ),+ ).join_n(&mut $out) + }}; +} + + /// Performs an N-ary zipper meet by borrowing all inputs mutably + /// and forwarding them to [`ZipperMergeF::meet_n`]. + /// + /// # Example + /// ``` + /// zipper_meet_n!(z1, z2, z3 => out); + /// ``` + /// + /// Expands roughly to: + /// ```ignore + /// (&mut z1, &mut z2, &mut z3).meet_n(&mut out) + /// ``` + /// + /// # See also + /// [`ZipperMergeF::meet_n`] + #[macro_export] + macro_rules! zipper_meet_n { + ( $($z:ident),+ => $out:ident ) => {{ + ( $( &mut $z ),+ ).meet_n(&mut $out) + }}; +} + + /// Performs an N-ary zipper subtract by borrowing all inputs mutably + /// and forwarding them to [`ZipperMergeF::subtract_n`]. + /// + /// # Example + /// ``` + /// zipper_subtract_n!(z1, z2, z3 => out); + /// ``` + /// + /// Expands roughly to: + /// ```ignore + /// (&mut z1, &mut z2, &mut z3).subtract_n(&mut out) + /// ``` + /// + /// # See also + /// [`ZipperMergeF::subtract_n`] + #[macro_export] + macro_rules! zipper_subtract_n { + ( $($z:ident),+ => $out:ident ) => {{ + ( $( &mut $z ),+ ).subtract_n(&mut $out) + }}; +} } #[cfg(test)] @@ -821,6 +1793,8 @@ mod tests { type Paths = &'static [(&'static [u8], u64)]; type BinaryTest = (Paths, Paths); type TernaryTest = (Paths, Paths, Paths); + type NaryTest = [Paths; N]; + const N: usize = 6; fn mk_binary_test(test: &BinaryTest) -> (PathMap, PathMap) { (PathMap::from_iter(test.0), PathMap::from_iter(test.1)) @@ -834,6 +1808,10 @@ mod tests { ) } + fn mk_nary_test(test: &NaryTest) -> [PathMap; N] { + test.map(PathMap::from_iter) + } + fn check2< 'x, T: IntoIterator, @@ -888,6 +1866,27 @@ mod tests { assert_trie(expected, result); } + fn checkn< + 'x, + T: IntoIterator, + F: for<'a> FnOnce([ReadZipperUntracked<'a, 'x, u64>; N], WriteZipperUntracked<'a, 'x, u64>), + >( + test: &NaryTest, + expected: T, + op: F, + ) { + let path_maps = mk_nary_test(test); + + let mut result = PathMap::new(); + + op( + path_maps.each_ref().map(PathMap::read_zipper), + result.write_zipper(), + ); + + assert_trie(expected, result); + } + fn assert_trie<'a, T: IntoIterator>( expected: T, result: PathMap, @@ -949,7 +1948,46 @@ mod tests { (&[0xFF, 0x00, 0x00], 2), (&[0xFF, 0x00, 0x00, 0x00], 3), ], - ); + ); + + const DISJOINT_PATHS_N: NaryTest = [ + &[ + (&[0x00], 0), + (&[0x00, 0x00], 1), + (&[0x00, 0x00, 0x00], 2), + (&[0x00, 0x00, 0x00, 0x00], 3), + ], + &[ + (&[0xC0], 0), + (&[0xC0, 0x00], 1), + (&[0xC0, 0x00, 0x00], 2), + (&[0xC0, 0x00, 0x00, 0x00], 3), + ], + &[ + (&[0xD0], 0), + (&[0xD0, 0x00], 1), + (&[0xD0, 0x00, 0x00], 2), + (&[0xD0, 0x00, 0x00, 0x00], 3), + ], + &[ + (&[0xE0], 0), + (&[0xE0, 0x00], 1), + (&[0xE0, 0x00, 0x00], 2), + (&[0xE0, 0x00, 0x00, 0x00], 3), + ], + &[ + (&[0xF0], 0), + (&[0xF0, 0x00], 1), + (&[0xF0, 0x00, 0x00], 2), + (&[0xF0, 0x00, 0x00, 0x00], 3), + ], + &[ + (&[0xFF], 0), + (&[0xFF, 0x00], 1), + (&[0xFF, 0x00, 0x00], 2), + (&[0xFF, 0x00, 0x00, 0x00], 3), + ], + ]; const PATHS_WITH_SHARED_PREFIX: BinaryTest = ( &[(b"aaaaa0", 0), (b"bbbbbbbb0", 1)], @@ -962,6 +2000,15 @@ mod tests { &[(b"aaaaa2", 0), (b"bbbbb2", 1), (b"bbbbbbbb2", 2)], ); + const PATHS_WITH_SHARED_PREFIX_N: NaryTest = [ + &[(b"aaaaa0", 0), (b"bbbbbbbb0", 1)], + &[(b"aaaaa1", 0), (b"bbbbb1", 1), (b"bbbbbbbb1", 2)], + &[(b"aaaaa2", 0), (b"bbbbb2", 1), (b"bbbbbbbb2", 2)], + &[(b"aaaaa3", 0), (b"bbbbb3", 1), (b"bbbbbbbb3", 2)], + &[(b"aaaaa4", 0), (b"bbbbb4", 1), (b"bbbbbbbb4", 2)], + &[(b"aaaaa5", 0), (b"bbbbb5", 1), (b"bbbbbbbb5", 2)], + ]; + const INTERLEAVING_PATHS: BinaryTest = ( &[(&[0], 0), (&[2], 1), (&[4], 2), (&[6], 3)], &[(&[1], 0), (&[3], 1), (&[5], 2), (&[7], 3)], @@ -973,6 +2020,15 @@ mod tests { &[(&[2], 0), (&[5], 1), (&[8], 2), (&[11], 3)], ); + const INTERLEAVING_PATHS_N: NaryTest = [ + &[(&[0], 0), (&[6], 1), (&[12], 2), (&[18], 3)], + &[(&[1], 0), (&[7], 1), (&[13], 2), (&[19], 3)], + &[(&[2], 0), (&[8], 1), (&[14], 2), (&[20], 3)], + &[(&[3], 0), (&[9], 1), (&[15], 2), (&[21], 3)], + &[(&[4], 0), (&[10], 1), (&[16], 2), (&[22], 3)], + &[(&[5], 0), (&[11], 1), (&[17], 2), (&[23], 3)], + ]; + const ONE_SIDED_PATHS: BinaryTest = ( &[ (&[0x00], 0), @@ -1014,6 +2070,30 @@ mod tests { &[(&[0x00], 0), (&[0x00, 0x01, 0x02], 1)], ); + const ONE_SIDED_PATHS_N: NaryTest = [ + &[ + (&[0x00], 0), + (&[0x00, 0x01], 1), + (&[0x00, 0x01, 0x02], 2), + (&[0x00, 0x01, 0x02, 0x03], 3), + (&[0x01], 4), + (&[0x01, 0x02], 5), + (&[0x01, 0x02, 0x03], 6), + (&[0x01, 0x02, 0x03, 0x04], 7), + (&[0x01, 0x02, 0x03, 0x04, 0x05], 8), + (&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06], 9), + ], + &[ + (&[0x00], 0), + (&[0x00, 0x01, 0x02, 0x03], 1), + (&[0x01, 0x02, 0x03, 0x04, 0x05], 2), + ], + &[(&[0x00], 0), (&[0x00, 0x01, 0x02], 1)], + &[(&[0x01], 4)], + &[(&[0x01], 4), (&[0x01, 0x02], 5)], + &[(&[0x01], 4), (&[0x01, 0x02, 0x03], 5)], + ]; + const ALMOST_IDENTICAL_PATHS: BinaryTest = ( &[ (b"abcdefg", 0), @@ -1071,13 +2151,94 @@ mod tests { ], ); + const ALMOST_IDENTICAL_PATHS_N: NaryTest = [ + &[ + (b"abcdefg", 0), + (b"hijklmnop", 1), + (b"qrstuwvxyz", 2), + (b"0", 3), + (b"1", 4), + (b"2", 5), + (b"3", 6), + (b"4", 7), + (b"5", 8), + (b"6789", 9), + ], + &[ + (b"abcdefg", 0), + (b"qrstuwvxyz", 2), + (b"0", 3), + (b"1", 4), + (b"4", 7), + (b"5", 8), + (b"6789", 9), + ], + &[ + (b"abcdefg", 0), + (b"hijklmnop", 1), + (b"1", 4), + (b"2", 5), + (b"3", 6), + (b"4", 7), + (b"5", 8), + ], + &[ + (b"hijklmnop", 1), + (b"1", 4), + (b"2", 5), + (b"3", 6), + (b"4", 7), + (b"5", 8), + ], + &[ + (b"hijklmnop", 1), + (b"1", 4), + (b"2", 5), + (b"3", 6), + (b"4", 7), + (b"5", 8), + ], + &[ + (b"hijklmnop", 1), + (b"1", 4), + (b"2", 5), + (b"3", 6), + (b"4", 7), + (b"5", 8), + ], + ]; + const LHS_EMPTY: BinaryTest = (&[], &[(&[1], 0), (&[2], 1)]); const LHS_EMPTY_3: TernaryTest = (&[], &[(&[1], 0), (&[2], 1)], &[(&[3], 0), (&[4], 1)]); + const LHS_EMPTY_N: NaryTest = [ + &[], + &[(&[1], 0), (&[2], 1)], + &[(&[3], 0), (&[4], 1)], + &[(&[5], 0)], + &[(&[6], 0)], + &[(&[7], 0), (&[8], 1)], + ]; const RHS_EMPTY: BinaryTest = (&[(&[1], 0), (&[2], 1)], &[]); const RHS_EMPTY_3: TernaryTest = (&[(&[1], 0), (&[2], 1)], &[(&[3], 0), (&[4], 1)], &[]); + const RHS_EMPTY_N: NaryTest = [ + &[(&[1], 0), (&[2], 1)], + &[(&[3], 0), (&[4], 1)], + &[(&[5], 0)], + &[(&[6], 0)], + &[(&[7], 0), (&[8], 1)], + &[], + ]; const MID_EMPTY: TernaryTest = (&[(&[1], 0), (&[2], 1)], &[], &[(&[3], 0), (&[4], 1)]); + const MID_EMPTY_N: NaryTest = [ + &[(&[1], 0), (&[2], 1)], + &[(&[3], 0), (&[4], 1)], + &[], + &[(&[5], 0)], + &[(&[6], 0)], + &[(&[7], 0), (&[8], 1)], + ]; const PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN: BinaryTest = ( &[ @@ -1110,6 +2271,39 @@ mod tests { ], ); + const PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_N: NaryTest = [ + &[ + (&[1, 2, 3], 0), + (&[1, 2, 3, 4], 1), + (&[1, 2, 3, 10, 11, 12], 2), + ], + &[ + (&[1, 2, 3], 10), + (&[1, 2, 3, 5], 11), + (&[1, 2, 3, 10, 11, 0], 12), + ], + &[ + (&[1, 2, 3], 20), + (&[1, 2, 3, 6], 21), + (&[1, 2, 3, 10, 11, 1], 22), + ], + &[ + (&[1, 2, 3], 30), + (&[1, 2, 3, 7], 31), + (&[1, 2, 3, 10, 11, 2], 32), + ], + &[ + (&[1, 2, 3], 40), + (&[1, 2, 3, 8], 41), + (&[1, 2, 3, 10, 11, 3], 42), + ], + &[ + (&[1, 2, 3], 50), + (&[1, 2, 3, 9], 51), + (&[1, 2, 3, 10, 11, 4], 52), + ], + ]; + const ZIGZAG_PATHS: BinaryTest = ( &[ (&[1, 1], 0), @@ -1167,9 +2361,21 @@ mod tests { &[(&[], 3), (&[1], 30)], ); + const PATHS_WITH_ROOT_VALS_AND_CHILDREN_N: NaryTest = [ + &[(&[], 1), (&[1], 10), (&[2], 110)], + &[(&[], 2), (&[1], 20), (&[2], 120)], + &[(&[], 3), (&[1], 30), (&[2], 130)], + &[(&[], 4), (&[1], 40), (&[2], 140)], + &[(&[], 5), (&[1], 50), (&[2], 150)], + &[(&[], 6), (&[1], 60), (&[2], 160)], + ]; + mod join { use super::*; - use crate::experimental::zipper_algebra::{ZipperAlgebraExt, zipper_join, zipper_join3}; + use crate::experimental::zipper_algebra::{ + ZipperAlgebraExt, ZipperMergeF, zipper_join, zipper_join3, + }; + use crate::zipper_join_n; #[test] fn test_disjoint() { @@ -1189,6 +2395,23 @@ mod tests { ); } + #[test] + fn test_disjoint_n() { + checkn( + &DISJOINT_PATHS_N, + &[ + DISJOINT_PATHS_N[0], + DISJOINT_PATHS_N[1], + DISJOINT_PATHS_N[2], + DISJOINT_PATHS_N[3], + DISJOINT_PATHS_N[4], + DISJOINT_PATHS_N[5], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_join_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_deep_shared_prefix_then_split() { check2( @@ -1212,6 +2435,23 @@ mod tests { ); } + #[test] + fn test_deep_shared_prefix_then_split_n() { + checkn( + &PATHS_WITH_SHARED_PREFIX_N, + &[ + PATHS_WITH_SHARED_PREFIX_N[0], + PATHS_WITH_SHARED_PREFIX_N[1], + PATHS_WITH_SHARED_PREFIX_N[2], + PATHS_WITH_SHARED_PREFIX_N[3], + PATHS_WITH_SHARED_PREFIX_N[4], + PATHS_WITH_SHARED_PREFIX_N[5], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_join_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_interleaving_paths() { check2( @@ -1235,6 +2475,23 @@ mod tests { ); } + #[test] + fn test_interleaving_paths_n() { + checkn( + &INTERLEAVING_PATHS_N, + &[ + INTERLEAVING_PATHS_N[0], + INTERLEAVING_PATHS_N[1], + INTERLEAVING_PATHS_N[2], + INTERLEAVING_PATHS_N[3], + INTERLEAVING_PATHS_N[4], + INTERLEAVING_PATHS_N[5], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_join_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_one_side_empty_at_many_levels() { check2(&ONE_SIDED_PATHS, ONE_SIDED_PATHS.0, |lhs, rhs, out| { @@ -1251,6 +2508,15 @@ mod tests { ); } + #[test] + fn test_one_side_empty_at_many_levels_n() { + checkn( + &ONE_SIDED_PATHS_N, + ONE_SIDED_PATHS_N[0], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_join_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_almost_identical_paths() { check2( @@ -1269,6 +2535,15 @@ mod tests { ); } + #[test] + fn test_almost_identical_paths_n() { + checkn( + &ALMOST_IDENTICAL_PATHS_N, + ALMOST_IDENTICAL_PATHS_N[0], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_join_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_one_side_empty() { check2(&LHS_EMPTY, LHS_EMPTY.1, |lhs, rhs, out| lhs.join(rhs, out)); @@ -1294,6 +2569,46 @@ mod tests { ); } + #[test] + fn test_one_side_empty_n() { + checkn( + &LHS_EMPTY_N, + &[ + LHS_EMPTY_N[1], + LHS_EMPTY_N[2], + LHS_EMPTY_N[3], + LHS_EMPTY_N[4], + LHS_EMPTY_N[5], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_join_n!(z0, z1, z2, z3, z4, z5 => out), + ); + checkn( + &MID_EMPTY_N, + &[ + MID_EMPTY_N[0], + MID_EMPTY_N[1], + MID_EMPTY_N[3], + MID_EMPTY_N[4], + MID_EMPTY_N[5], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_join_n!(z0, z1, z2, z3, z4, z5 => out), + ); + checkn( + &RHS_EMPTY_N, + &[ + RHS_EMPTY_N[0], + RHS_EMPTY_N[1], + RHS_EMPTY_N[2], + RHS_EMPTY_N[3], + RHS_EMPTY_N[4], + ] + .concat(), + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_join_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_exact_overlap_divergent_subtries() { let expected: Paths = &[ @@ -1328,6 +2643,30 @@ mod tests { ); } + #[test] + fn test_exact_overlap_divergent_subtries_n() { + let expected: Paths = &[ + (&[1, 2, 3], 0), + (&[1, 2, 3, 4], 1), + (&[1, 2, 3, 5], 11), + (&[1, 2, 3, 6], 21), + (&[1, 2, 3, 7], 31), + (&[1, 2, 3, 8], 41), + (&[1, 2, 3, 9], 51), + (&[1, 2, 3, 10, 11, 0], 12), + (&[1, 2, 3, 10, 11, 1], 22), + (&[1, 2, 3, 10, 11, 2], 32), + (&[1, 2, 3, 10, 11, 3], 42), + (&[1, 2, 3, 10, 11, 4], 52), + (&[1, 2, 3, 10, 11, 12], 2), + ]; + checkn( + &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_N, + expected, + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_join_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_zigzag() { check2( @@ -1341,7 +2680,7 @@ mod tests { fn test_zigzag3() { check3( &ZIGZAG_PATHS_3, - &[ZIGZAG_PATHS.0, ZIGZAG_PATHS.1, ZIGZAG_PATHS_3.2].concat(), + &[ZIGZAG_PATHS_3.0, ZIGZAG_PATHS_3.1, ZIGZAG_PATHS_3.2].concat(), |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), ); } @@ -1363,11 +2702,23 @@ mod tests { |lhs, mid, rhs, out| zipper_join3(lhs, mid, rhs, out), ); } + + #[test] + fn test_root_values_n() { + checkn( + &PATHS_WITH_ROOT_VALS_AND_CHILDREN_N, + PATHS_WITH_ROOT_VALS_AND_CHILDREN_N[0], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_join_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } } mod meet { use super::*; - use crate::experimental::zipper_algebra::{ZipperAlgebraExt, zipper_meet, zipper_meet3}; + use crate::experimental::zipper_algebra::{ + ZipperAlgebraExt, ZipperMergeF, zipper_meet, zipper_meet3, + }; + use crate::zipper_meet_n; #[test] fn test_disjoint() { @@ -1383,6 +2734,15 @@ mod tests { }); } + #[test] + fn test_disjoint_n() { + checkn( + &DISJOINT_PATHS_N, + [], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_meet_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_deep_shared_prefix_then_split() { check2(&PATHS_WITH_SHARED_PREFIX, [], |lhs, rhs, out| { @@ -1397,6 +2757,15 @@ mod tests { }); } + #[test] + fn test_deep_shared_prefix_then_split_n() { + checkn( + &PATHS_WITH_SHARED_PREFIX_N, + [], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_meet_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_interleaving_paths() { check2(&INTERLEAVING_PATHS, [], |lhs, rhs, out| { @@ -1411,6 +2780,15 @@ mod tests { }); } + #[test] + fn test_interleaving_paths_n() { + checkn( + &INTERLEAVING_PATHS_N, + [], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_meet_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_one_side_empty_at_many_levels() { let expected: Paths = &[ @@ -1431,6 +2809,15 @@ mod tests { }); } + #[test] + fn test_one_side_empty_at_many_levels_n() { + checkn( + &ONE_SIDED_PATHS_N, + [], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_meet_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_almost_identical_paths() { check2( @@ -1448,6 +2835,16 @@ mod tests { }); } + #[test] + fn test_almost_identical_paths_n() { + let expected: Paths = &[(b"1", 4), (b"4", 7), (b"5", 8)]; + checkn( + &ALMOST_IDENTICAL_PATHS_N, + expected, + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_meet_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_one_side_empty() { check2(&LHS_EMPTY, [], |lhs, rhs, out| lhs.meet(rhs, out)); @@ -1467,6 +2864,25 @@ mod tests { }); } + #[test] + fn test_one_side_empty_n() { + checkn( + &LHS_EMPTY_N, + [], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_meet_n!(z0, z1, z2, z3, z4, z5 => out), + ); + checkn( + &MID_EMPTY_N, + [], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_meet_n!(z0, z1, z2, z3, z4, z5 => out), + ); + checkn( + &RHS_EMPTY_N, + [], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_meet_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_exact_overlap_divergent_subtries() { let expected: Paths = &[(&[1, 2, 3], 0)]; @@ -1487,6 +2903,16 @@ mod tests { ); } + #[test] + fn test_exact_overlap_divergent_subtries_n() { + let expected: Paths = &[(&[1, 2, 3], 0)]; + checkn( + &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_N, + expected, + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_meet_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_zigzag() { let expected: Paths = &[(&[2, 1], 2), (&[3], 3)]; @@ -1520,13 +2946,23 @@ mod tests { |lhs, mid, rhs, out| zipper_meet3(lhs, mid, rhs, out), ); } + + #[test] + fn test_root_values_n() { + checkn( + &PATHS_WITH_ROOT_VALS_AND_CHILDREN_N, + PATHS_WITH_ROOT_VALS_AND_CHILDREN_N[0], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_meet_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } } mod subtract { use super::*; use crate::experimental::zipper_algebra::{ - ZipperAlgebraExt, zipper_subtract, zipper_subtract3, + ZipperAlgebraExt, ZipperMergeF, zipper_subtract, zipper_subtract3, }; + use crate::zipper_subtract_n; #[test] fn test_disjoint() { @@ -1546,6 +2982,15 @@ mod tests { ); } + #[test] + fn test_disjoint_n() { + checkn( + &DISJOINT_PATHS_N, + DISJOINT_PATHS_N[0], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_subtract_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_deep_shared_prefix_then_split() { check2( @@ -1566,6 +3011,15 @@ mod tests { ); } + #[test] + fn test_deep_shared_prefix_then_split_n() { + checkn( + &PATHS_WITH_SHARED_PREFIX_N, + PATHS_WITH_SHARED_PREFIX_N[0], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_subtract_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_interleaving_paths() { check2( @@ -1586,6 +3040,15 @@ mod tests { ); } + #[test] + fn test_interleaving_paths_n() { + checkn( + &INTERLEAVING_PATHS_N, + INTERLEAVING_PATHS_N[0], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_subtract_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_one_side_empty_at_many_levels() { let expected: Paths = &[ @@ -1622,6 +3085,24 @@ mod tests { }); } + #[test] + fn test_one_side_empty_at_many_levels_n() { + let expected: Paths = &[ + (&[0x00, 0x01], 1), + (&[0x00, 0x01, 0x02], 2), + (&[0x00, 0x01, 0x02, 0x03], 3), + (&[0x01, 0x02, 0x03], 6), + (&[0x01, 0x02, 0x03, 0x04], 7), + (&[0x01, 0x02, 0x03, 0x04, 0x05], 8), + (&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06], 9), + ]; + checkn( + &ONE_SIDED_PATHS_N, + expected, + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_subtract_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_almost_identical_paths() { let expected: Paths = &[(b"hijklmnop", 1), (b"2", 5), (b"3", 6)]; @@ -1637,6 +3118,15 @@ mod tests { }); } + #[test] + fn test_almost_identical_paths_n() { + checkn( + &ALMOST_IDENTICAL_PATHS_N, + [], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_subtract_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_one_side_empty() { check2(&LHS_EMPTY, [], |lhs, rhs, out| lhs.subtract(rhs, out)); @@ -1658,6 +3148,25 @@ mod tests { }); } + #[test] + fn test_one_side_empty_n() { + checkn( + &LHS_EMPTY_N, + [], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_subtract_n!(z0, z1, z2, z3, z4, z5 => out), + ); + checkn( + &MID_EMPTY_N, + MID_EMPTY_N[0], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_subtract_n!(z0, z1, z2, z3, z4, z5 => out), + ); + checkn( + &RHS_EMPTY_N, + RHS_EMPTY_N[0], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_subtract_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_exact_overlap_divergent_subtries() { check2( @@ -1678,6 +3187,15 @@ mod tests { ); } + #[test] + fn test_exact_overlap_divergent_subtries_n() { + checkn( + &PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_N, + PATHS_WITH_SAME_PREFIX_DIFFERENT_CHILDREN_N[0], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_subtract_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } + #[test] fn test_zigzag() { let expected: Paths = &[ @@ -1719,5 +3237,14 @@ mod tests { }, ); } + + #[test] + fn test_root_values_n() { + checkn( + &PATHS_WITH_ROOT_VALS_AND_CHILDREN_N, + PATHS_WITH_ROOT_VALS_AND_CHILDREN_N[0], + |[mut z0, mut z1, mut z2, mut z3, mut z4, mut z5], mut out| zipper_subtract_n!(z0, z1, z2, z3, z4, z5 => out), + ); + } } }