From b11aac3a4bf980a644ab23f953ebce11cb80f066 Mon Sep 17 00:00:00 2001 From: Barnadrot Date: Thu, 7 May 2026 09:57:17 +0200 Subject: [PATCH 1/5] pw3-13: switch AIR mds_air_16 to FFT MDS (50 mults vs Karatsuba 72) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Poseidon AIR constraint folder evaluates mds_air_16 8x per row across runtime types (F, EF, FPacking, EFPacking). Previously this used Karatsuba convolution (72 mults). Switch to the same FFT-MDS already used in the permute_simd hot path: DIT_FFT(lambda/16 ⊙ DIF_IFFT(state)), 50 mults. Saves 22 mults × 8 MDS calls per AIR row = 176 mults/row, ~10% reduction in AIR Poseidon eval mult count. AIR Poseidon eval is ~10% of CPU time in the e2e prover (eval_2_full_rounds_16 + eval_last_2_full_rounds_16 + Poseidon16Precompile::eval). The unpacked lambda_over_16 = (DIF_IFFT(MDS_CIRC_COL) * 16^-1) is factored out of the SimdPrecomputed branch and stored at the top of Precomputed; the SIMD branch reuses it (no duplication). FFT helpers (bt/dit/neg_dif/dif_ifft/dit_fft) are ungated from target_feature since they're pure generic Rust, and their bound is relaxed from Algebra to PrimeCharacteristicRing + Mul to match mds_circ_16 (so EFPacking, which lacks Algebra, is admitted). Predicted magnitude: medium (1.0-1.5%). 
--- .../koala-bear/src/poseidon1_koalabear_16.rs | 103 ++++++++---------- crates/lean_vm/src/tables/poseidon_16/mod.rs | 12 +- 2 files changed, 54 insertions(+), 61 deletions(-) diff --git a/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs b/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs index 41be54ff5..f80a9a7b2 100644 --- a/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs +++ b/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs @@ -25,86 +25,42 @@ const MDS_CIRC_COL: [KoalaBear; 16] = KoalaBear::new_array([1, 3, 13, 22, 67, 2, // Forward twiddles for 16-point FFT: W_k = omega^k // ========================================================================= -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W1: KoalaBear = KoalaBear::new(0x08dbd69c); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W2: KoalaBear = KoalaBear::new(0x6832fe4a); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W3: KoalaBear = KoalaBear::new(0x27ae21e2); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W4: KoalaBear = KoalaBear::new(0x7e010002); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W5: KoalaBear = KoalaBear::new(0x3a89a025); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W6: KoalaBear = KoalaBear::new(0x174e3650); -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] const W7: KoalaBear = KoalaBear::new(0x27dfce22); // 
========================================================================= // 16-point FFT / IFFT (radix-2, fully unrolled, in-place) // ========================================================================= -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] #[inline(always)] -fn bt>(v: &mut [R; 16], lo: usize, hi: usize) { +fn bt>(v: &mut [R; 16], lo: usize, hi: usize) { let (a, b) = (v[lo], v[hi]); v[lo] = a + b; v[hi] = a - b; } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] #[inline(always)] -fn dit>(v: &mut [R; 16], lo: usize, hi: usize, t: KoalaBear) { +fn dit>(v: &mut [R; 16], lo: usize, hi: usize, t: KoalaBear) { let a = v[lo]; let tb = v[hi] * t; v[lo] = a + tb; v[hi] = a - tb; } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] #[inline(always)] -fn neg_dif>(v: &mut [R; 16], lo: usize, hi: usize, t: KoalaBear) { +fn neg_dif>(v: &mut [R; 16], lo: usize, hi: usize, t: KoalaBear) { let (a, b) = (v[lo], v[hi]); v[lo] = a + b; v[hi] = (b - a) * t; } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] #[inline(always)] -fn dif_ifft_16_mut>(f: &mut [R; 16]) { +fn dif_ifft_16_mut>(f: &mut [R; 16]) { bt(f, 0, 8); neg_dif(f, 1, 9, W7); neg_dif(f, 2, 10, W6); @@ -139,12 +95,8 @@ fn dif_ifft_16_mut>(f: &mut [R; 16]) { bt(f, 14, 15); } -#[cfg(any( - all(target_arch = "aarch64", target_feature = "neon"), - all(target_arch = "x86_64", target_feature = "avx2") -))] #[inline(always)] -fn dit_fft_16_mut>(f: &mut [R; 16]) { +fn dit_fft_16_mut>(f: &mut [R; 16]) { bt(f, 0, 1); bt(f, 2, 3); bt(f, 4, 5); @@ -543,6 +495,11 @@ struct Precomputed { /// Length = POSEIDON1_PARTIAL_ROUNDS - 1. 
sparse_round_constants: Vec, + // --- FFT MDS eigenvalues (unpacked, for AIR / generic types) --- + /// `lambda_over_16[i]` = (DIF_IFFT(MDS_CIRC_COL))[i] * 16^{-1}. + /// Used by `mds_fft_16` to compute the circulant MDS via FFT. + lambda_over_16: [KoalaBear; 16], + // --- SIMD pre-packed constants (NEON / AVX2 / AVX512) --- #[cfg(any( all(target_arch = "aarch64", target_feature = "neon"), @@ -634,6 +591,16 @@ fn precomputed() -> &'static Precomputed { .map(|w| core::array::from_fn(|i| if i == 0 { mds_0_0 } else { w[i - 1] })) .collect(); + // --- FFT MDS eigenvalues (unpacked) --- + // C * x = DIT_FFT((lambda/16) ⊙ DIF_IFFT(x)) — same identity used in + // `permute_simd`, factored out for the AIR / generic mds_fft_16 path. + let lambda_over_16: [KoalaBear; 16] = { + let mut lambda_br = MDS_CIRC_COL; + dif_ifft_16_mut(&mut lambda_br); + let inv16 = KoalaBear::new(1997537281); // 16^{-1} mod p + lambda_br.map(|l| l * inv16) + }; + // --- SIMD pre-packed constants (same layout for NEON / AVX2 / AVX512) --- #[cfg(any( all(target_arch = "aarch64", target_feature = "neon"), @@ -675,10 +642,8 @@ fn precomputed() -> &'static Precomputed { let packed_fused_bias: [PackedKB; 16] = fused_bias.map(pack); // Pre-packed eigenvalues * INV16 (absorbs /16 into eigenvalues). - let mut lambda_br = MDS_CIRC_COL; - dif_ifft_16_mut(&mut lambda_br); - let inv16 = KoalaBear::new(1997537281); // 16^{-1} mod p - let packed_lambda_over_16: [PackedKB; 16] = core::array::from_fn(|i| pack(lambda_br[i] * inv16)); + // Reuse the unpacked lambda computed above. 
+ let packed_lambda_over_16: [PackedKB; 16] = lambda_over_16.map(pack); SimdPrecomputed { packed_initial_rc, @@ -699,6 +664,7 @@ fn precomputed() -> &'static Precomputed { sparse_first_row, sparse_v, sparse_round_constants: scalar_round_constants, + lambda_over_16, #[cfg(any( all(target_arch = "aarch64", target_feature = "neon"), all(target_arch = "x86_64", target_feature = "avx2") @@ -708,6 +674,29 @@ fn precomputed() -> &'static Precomputed { }) } +/// Eigenvalues of the circulant MDS matrix, divided by 16 (the unnormalized +/// FFT round-trip scaling). Used by [`mds_fft_16`]. +#[inline(always)] +pub fn poseidon1_lambda_over_16() -> &'static [KoalaBear; 16] { + &precomputed().lambda_over_16 +} + +/// Circulant MDS multiply via 16-point FFT (50 mults vs 72 for Karatsuba). +/// +/// Computes `state = C * state = DIT_FFT((lambda/16) o DIF_IFFT(state))`. +/// Bitwise-identical to `mds_circ_16` but with fewer multiplications. +/// Used by the AIR constraint folder where MDS is evaluated per row over +/// (packed) field types. +#[inline(always)] +pub fn mds_fft_16>(state: &mut [R; 16]) { + let lambda = poseidon1_lambda_over_16(); + dif_ifft_16_mut(state); + for i in 0..16 { + state[i] = state[i] * lambda[i]; + } + dit_fft_16_mut(state); +} + // ========================================================================= // Round constants (Grain LFSR, matching Plonky3) // ========================================================================= diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index 5cffe5194..cd047aa21 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -5,9 +5,13 @@ use crate::{execution::memory::MemoryAccess, tables::poseidon_16::trace_gen::gen use backend::*; use utils::{ToUsize, poseidon16_compress}; -/// Dispatch `mds_circ_16` through concrete types. 
-/// For `SymbolicExpression` we use the dense form so the zkDSL generator can -/// emit `dot_product_be` precompile calls instead of Karatsuba arithmetic. +/// Dispatch the circulant MDS multiply through concrete types. +/// +/// - `SymbolicExpression`: dense matrix-vector form so the zkDSL generator can +/// emit `dot_product_be` precompile calls instead of Karatsuba arithmetic. +/// - Runtime field types (F, EF, FPacking, EFPacking): FFT-based MDS +/// (`mds_fft_16`, 50 mults) instead of Karatsuba (`mds_circ_16`, 72 mults). +/// Same algebraic result; ~30% fewer mults per call. #[inline(always)] fn mds_air_16(state: &mut [A; WIDTH]) { if TypeId::of::() == TypeId::of::>() { @@ -17,7 +21,7 @@ fn mds_air_16(state: &mut [A; WIDTH]) { macro_rules! dispatch { ($t:ty) => { if TypeId::of::() == TypeId::of::<$t>() { - mds_circ_16::<$t>(unsafe { &mut *(state as *mut [A; WIDTH] as *mut [$t; WIDTH]) }); + mds_fft_16::<$t>(unsafe { &mut *(state as *mut [A; WIDTH] as *mut [$t; WIDTH]) }); return; } }; From 2731904492a90866c933691914b6c311e6af195a Mon Sep 17 00:00:00 2001 From: Barnadrot Date: Fri, 8 May 2026 14:39:35 +0200 Subject: [PATCH 2/5] perf: sponge RATE=8->12 with capacity=4 + zk-DSL RATE=12 port Reduce Poseidon permutations per Merkle leaf by 22-32% by increasing the sponge absorption rate from 8 to 12 field elements per permutation call. 
Changes: - sponge.rs: relax RATE==OUT and WIDTH==OUT+RATE asserts, support arbitrary RATE - merkle.rs: SPONGE_RATE=12, padded_full_base_width helper, corrected n_zero_suffix_rate_chunks formula for RATE!=WIDTH/2 - verifier.rs: pad base_data to sponge-aligned length before hashing - hashing.py: zk-DSL slice_hash_rtl rewritten for RATE=12, @inline removed to fix conditional branch fall-through bug --- crates/backend/fiat-shamir/src/verifier.rs | 15 ++- crates/backend/symetric/src/sponge.rs | 56 ++++++++--- .../rec_aggregation/zkdsl_implem/hashing.py | 97 +++++++++++++++++-- crates/whir/src/merkle.rs | 62 ++++++++++-- 4 files changed, 200 insertions(+), 30 deletions(-) diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index 9bbc26bd7..4459a8d60 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -72,7 +72,20 @@ where // SAFETY: We've confirmed PF == KoalaBear let paths: PrunedMerklePaths = unsafe { std::mem::transmute(paths) }; let perm = default_koalabear_poseidon1_16(); - let hash_fn = |data: &[KoalaBear]| symetric::hash_slice::<_, _, 16, 8, DIGEST_LEN_FE>(&perm, data); + let hash_fn = |data: &[KoalaBear]| { + // Pad data up to the smallest sponge-aligned length so that + // (padded - WIDTH) is a multiple of RATE. The prover's + // build_merkle_tree_koalabear pads identically before hashing. 
+ const W: usize = 16; + const R: usize = 12; + let mut padded_len = data.len().max(W); + while !(padded_len - W).is_multiple_of(R) { + padded_len += 1; + } + let mut buf: Vec = data.to_vec(); + buf.resize(padded_len, KoalaBear::default()); + symetric::hash_slice::<_, _, 16, 12, DIGEST_LEN_FE>(&perm, &buf) + }; let combine_fn = |left: &[KoalaBear; DIGEST_LEN_FE], right: &[KoalaBear; DIGEST_LEN_FE]| { symetric::compress(&perm, [*left, *right]) }; diff --git a/crates/backend/symetric/src/sponge.rs b/crates/backend/symetric/src/sponge.rs index ebea80a9e..a65b52e25 100644 --- a/crates/backend/symetric/src/sponge.rs +++ b/crates/backend/symetric/src/sponge.rs @@ -2,21 +2,22 @@ use crate::Compression; -// IV should have been added to data when necessary (typically: when the length of the data beeing hashed is not constant). Maybe we should re-add IV all the time for simplicity? -// assumes data length is a multiple of RATE (= 8 in practice). +// IV should have been added to data when necessary (typically: when the length of the data beeing hashed is not constant). +// Sponge construction with capacity = WIDTH - RATE. +// Constraint: data.len() >= WIDTH and (data.len() - WIDTH) is a multiple of RATE. 
pub fn hash_slice(comp: &Comp, data: &[T]) -> [T; OUT] where T: Default + Copy, Comp: Compression<[T; WIDTH]>, { - debug_assert!(RATE == OUT); - debug_assert!(WIDTH == OUT + RATE); - debug_assert!(data.len().is_multiple_of(RATE)); - let n_chunks = data.len() / RATE; - debug_assert!(n_chunks >= 2); + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); + debug_assert!(data.len() >= WIDTH); + debug_assert!((data.len() - WIDTH).is_multiple_of(RATE)); let mut state: [T; WIDTH] = data[data.len() - WIDTH..].try_into().unwrap(); comp.compress_mut(&mut state); - for chunk_idx in (0..n_chunks - 2).rev() { + let n_remaining_chunks = (data.len() - WIDTH) / RATE; + for chunk_idx in (0..n_remaining_chunks).rev() { let offset = chunk_idx * RATE; state[WIDTH - RATE..].copy_from_slice(&data[offset..offset + RATE]); comp.compress_mut(&mut state); @@ -24,7 +25,8 @@ where state[..OUT].try_into().unwrap() } -/// Precompute sponge state after absorbing `n_zero_chunks` all-zero RATE-chunks. +/// Precompute sponge state after `n_zero_chunks - 1` zero compresses +/// (1 for initial WIDTH zeros + (n-2) RATE-zero absorbs). 
pub fn precompute_zero_suffix_state( comp: &Comp, n_zero_chunks: usize, @@ -33,8 +35,8 @@ where T: Default + Copy, Comp: Compression<[T; WIDTH]>, { - debug_assert!(RATE == OUT); - debug_assert!(WIDTH == OUT + RATE); + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); debug_assert!(n_zero_chunks >= 2); let mut state = [T::default(); WIDTH]; comp.compress_mut(&mut state); @@ -58,8 +60,8 @@ where Comp: Compression<[T; WIDTH]>, I: IntoIterator, { - debug_assert!(RATE == OUT); - debug_assert!(WIDTH == OUT + RATE); + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); let mut state = [T::default(); WIDTH]; let mut iter = rtl_iter.into_iter(); for pos in (0..WIDTH).rev() { @@ -106,3 +108,31 @@ where } state[..OUT].try_into().unwrap() } + +#[cfg(test)] +mod tests { + use super::*; + use koala_bear::{KoalaBear, default_koalabear_poseidon1_16}; + use field::PrimeCharacteristicRing; + + /// Verify hash_slice(D) == hash_rtl_iter(D.iter().rev()) for arbitrary D with valid length. + #[test] + fn hash_slice_matches_rtl_iter_rate12() { + let perm = default_koalabear_poseidon1_16(); + // 100 = 16 + 12*7, compatible with WIDTH=16, RATE=12 + let data: Vec = (0..100u32).map(|i| KoalaBear::from_u32(i)).collect(); + let h_slice = hash_slice::(&perm, &data); + let h_rtl = hash_rtl_iter::(&perm, data.iter().rev().copied()); + assert_eq!(h_slice, h_rtl, "hash_slice and hash_rtl_iter must agree on equivalent inputs"); + } + + /// Same as above but for the existing RATE=8 case. 
+ #[test] + fn hash_slice_matches_rtl_iter_rate8() { + let perm = default_koalabear_poseidon1_16(); + let data: Vec = (0..64u32).map(|i| KoalaBear::from_u32(i)).collect(); + let h_slice = hash_slice::(&perm, &data); + let h_rtl = hash_rtl_iter::(&perm, data.iter().rev().copied()); + assert_eq!(h_slice, h_rtl, "hash_slice and hash_rtl_iter must agree on equivalent inputs (RATE=8)"); + } +} diff --git a/crates/rec_aggregation/zkdsl_implem/hashing.py b/crates/rec_aggregation/zkdsl_implem/hashing.py index ef501732a..f9a0870d6 100644 --- a/crates/rec_aggregation/zkdsl_implem/hashing.py +++ b/crates/rec_aggregation/zkdsl_implem/hashing.py @@ -54,14 +54,99 @@ def batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hash return -@inline def slice_hash_rtl(data, num_chunks): - states = Array((num_chunks - 1) * DIGEST_LEN) + """RATE=12 sponge over data of length num_chunks * 8 base elements. + Pads internally so that the absorbed length is 16 + 12*k (sponge-aligned), + matching the native prover's padded_full_base_width helper. + + `num_chunks` is `Const`, so all arithmetic and the if-branches below + resolve at compile time. + + Algorithm (mirrors Rust hash_rtl_iter for RATE=12, WIDTH=16): + state = padded_data[L-16..L] # initial state from last 16 elements + compress(state) + for chunk_idx descending from k-1 to 0: + state[0..4] persists (capacity) + state[4..16] = padded_data[chunk_idx*12..(chunk_idx+1)*12] + compress(state) + return state[0..8] + """ + if num_chunks == 1: + # data_len = 8 ; pad to 16 ; one permute. 
+ buf = Array(16) + for i in unroll(0, 8): + buf[i] = data[i] + for i in unroll(8, 16): + buf[i] = 0 + result = Array(DIGEST_LEN) + poseidon16_compress(buf, buf + DIGEST_LEN, result) + return result + if num_chunks == 4: + return slice_hash_rtl_rate12(data, 32, 40, 2) + if num_chunks == 5: + return slice_hash_rtl_rate12(data, 40, 40, 2) + if num_chunks == 8: + return slice_hash_rtl_rate12(data, 64, 64, 4) + if num_chunks == 10: + return slice_hash_rtl_rate12(data, 80, 88, 6) + if num_chunks == 16: + return slice_hash_rtl_rate12(data, 128, 136, 10) + if num_chunks == 20: + return slice_hash_rtl_rate12(data, 160, 160, 12) + print(num_chunks) + assert False, "slice_hash_rtl called with unsupported num_chunks" + + +def slice_hash_rtl_rate12(data, data_len: Const, padded_len: Const, n_chunks_12: Const): + """Internal helper for RATE=12 sponge with explicit padding params. + Pre: padded_len = 16 + n_chunks_12 * 12 ; padded_len >= data_len. + """ + if padded_len == data_len: + # No padding needed; absorb directly from data. + return slice_hash_rtl_rate12_no_pad(data, padded_len, n_chunks_12) + # Build a local padded buffer once, then absorb from it. + padded_data = Array(padded_len) + for i in unroll(0, data_len): + padded_data[i] = data[i] + for i in unroll(data_len, padded_len): + padded_data[i] = 0 + return slice_hash_rtl_rate12_no_pad(padded_data, padded_len, n_chunks_12) + + +def slice_hash_rtl_rate12_no_pad(padded_data, padded_len: Const, n_chunks_12: Const): + # states[k*8..(k+1)*8] is the 8-element output of round k's compress. + states = Array((n_chunks_12 + 1) * DIGEST_LEN) + + # Round 0: initial state from last 16 elements of padded_data. 
+ poseidon16_compress( + padded_data + padded_len - 16, + padded_data + padded_len - 8, + states, + ) - poseidon16_compress(data + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, states) - for j in unroll(1, num_chunks - 1): - poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + (num_chunks - 2 - j) * DIGEST_LEN, states + j * DIGEST_LEN) - return states + (num_chunks - 2) * DIGEST_LEN + # Subsequent rounds: absorb 12-element chunks RTL. + for j in unroll(0, n_chunks_12): + chunk_idx = n_chunks_12 - 1 - j + + # Build left input (8 elements): [capacity_4 || chunk[0..4]]. + buf = Array(DIGEST_LEN) + buf[0] = states[j * DIGEST_LEN + 0] + buf[1] = states[j * DIGEST_LEN + 1] + buf[2] = states[j * DIGEST_LEN + 2] + buf[3] = states[j * DIGEST_LEN + 3] + buf[4] = padded_data[chunk_idx * 12 + 0] + buf[5] = padded_data[chunk_idx * 12 + 1] + buf[6] = padded_data[chunk_idx * 12 + 2] + buf[7] = padded_data[chunk_idx * 12 + 3] + + # Right input: chunk[4..12]. Output -> states[(j+1)*8..(j+2)*8]. + poseidon16_compress( + buf, + padded_data + chunk_idx * 12 + 4, + states + (j + 1) * DIGEST_LEN, + ) + + return states + n_chunks_12 * DIGEST_LEN @inline diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index b5517cd09..119e4d614 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -8,6 +8,7 @@ use field::BasedVectorSpace; use field::ExtensionField; use field::Field; use field::PackedValue; +use field::PrimeCharacteristicRing; use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; use poly::*; @@ -55,6 +56,23 @@ pub(crate) fn merkle_commit>( } } +// Sponge rate for Merkle leaf hashing. WIDTH=16 (Poseidon1KoalaBear16) gives +// capacity = WIDTH - RATE = 4 with RATE=12. See iter 29 WIP analysis for the +// security tradeoff. +const SPONGE_RATE: usize = 12; +const SPONGE_WIDTH: usize = 16; + +/// Pad up so that (padded - WIDTH) is divisible by RATE. 
Used only for sponge +/// alignment; the protocol-visible leaf width stays unpadded. +#[inline] +fn padded_full_base_width(full_base_width: usize) -> usize { + let mut padded = full_base_width.max(SPONGE_WIDTH); + while !(padded - SPONGE_WIDTH).is_multiple_of(SPONGE_RATE) { + padded += 1; + } + padded +} + #[instrument(name = "build merkle tree", skip_all)] fn build_merkle_tree_koalabear( leaf: DenseMatrix, @@ -62,24 +80,43 @@ fn build_merkle_tree_koalabear( effective_base_width: usize, ) -> RoundMerkleTree { let perm = default_koalabear_poseidon1_16(); - let n_zero_suffix_rate_chunks = (full_base_width - effective_base_width) / 8; + // Internal padding for sponge alignment. NOT exposed to the protocol layer. + let padded_full_width = padded_full_base_width(full_base_width); + // n_zero_suffix_rate_chunks = number of "zero RATE-chunks" the precompute + // must cover so that the remaining iter (effective + n_pad elements, + // where n_pad rounds effective up to a multiple of RATE) takes the sponge + // exactly to padded_full_width. precompute(n) does n-1 compresses, + // absorbing 16 zeros initially + (n-2)*RATE more = WIDTH + (n-2)*RATE. + // Total: WIDTH + (n-2)*RATE + (effective + n_pad) = padded. + // Solving: n = 2 + (padded - WIDTH - effective - n_pad) / RATE. 
+ let n_pad = (SPONGE_RATE - effective_base_width % SPONGE_RATE) % SPONGE_RATE; + let n_zero_suffix_rate_chunks = if padded_full_width >= SPONGE_WIDTH + effective_base_width + n_pad { + 2 + (padded_full_width - SPONGE_WIDTH - effective_base_width - n_pad) / SPONGE_RATE + } else { + 0 + }; let first_layer = if n_zero_suffix_rate_chunks >= 2 { - let scalar_state = symetric::precompute_zero_suffix_state::( + let scalar_state = symetric::precompute_zero_suffix_state::( &perm, n_zero_suffix_rate_chunks, ); - let packed_state: [PFPacking; 16] = + let packed_state: [PFPacking; SPONGE_WIDTH] = std::array::from_fn(|i| PFPacking::::from_fn(|_| scalar_state[i])); - first_digest_layer_with_initial_state::, _, _, DIGEST_ELEMS, 16, 8>( + first_digest_layer_with_initial_state::, _, _, DIGEST_ELEMS, SPONGE_WIDTH, SPONGE_RATE>( &perm, &leaf, &packed_state, effective_base_width, ) } else { - first_digest_layer::, _, _, DIGEST_ELEMS, 16, 8>(&perm, &leaf, full_base_width) + first_digest_layer::, _, _, DIGEST_ELEMS, SPONGE_WIDTH, SPONGE_RATE>( + &perm, + &leaf, + padded_full_width, + ) }; - let tree = symetric::merkle::MerkleTree::from_first_layer::, _, 16>(&perm, first_layer); + let tree = symetric::merkle::MerkleTree::from_first_layer::, _, SPONGE_WIDTH>(&perm, first_layer); + // Expose UNPADDED width to the protocol; padding is purely a sponge detail. WhirMerkleTree { leaf, tree, @@ -125,8 +162,11 @@ pub(crate) fn merkle_verify>( let merkle_root = unsafe { std::mem::transmute_copy::<_, [KoalaBear; DIGEST_ELEMS]>(&merkle_root) }; let data = unsafe { std::mem::transmute::<_, Vec>(data) }; let proof = unsafe { std::mem::transmute::<_, &Vec<[KoalaBear; DIGEST_ELEMS]>>(proof) }; - let base_data = QuinticExtensionFieldKB::flatten_to_base(data); - symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, 16, 8>( + let mut base_data = QuinticExtensionFieldKB::flatten_to_base(data); + // Pad to the sponge-aligned width (matches the prover's internal padding). 
+ let padded = padded_full_base_width(base_data.len()); + base_data.resize(padded, KoalaBear::ZERO); + symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, SPONGE_WIDTH, SPONGE_RATE>( &perm, &merkle_root, log_max_height, @@ -138,8 +178,10 @@ pub(crate) fn merkle_verify>( let merkle_root = unsafe { std::mem::transmute_copy::<_, [KoalaBear; DIGEST_ELEMS]>(&merkle_root) }; let data = unsafe { std::mem::transmute::<_, Vec>(data) }; let proof = unsafe { std::mem::transmute::<_, &Vec<[KoalaBear; DIGEST_ELEMS]>>(proof) }; - let base_data = KoalaBear::flatten_to_base(data); - symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, 16, 8>( + let mut base_data = KoalaBear::flatten_to_base(data); + let padded = padded_full_base_width(base_data.len()); + base_data.resize(padded, KoalaBear::ZERO); + symetric::merkle::merkle_verify::<_, _, DIGEST_ELEMS, SPONGE_WIDTH, SPONGE_RATE>( &perm, &merkle_root, log_max_height, From 2198c0b4df5b40bfb9d5cd656fc6c51a0f8dfb08 Mon Sep 17 00:00:00 2001 From: Barnadrot Date: Fri, 8 May 2026 14:39:56 +0200 Subject: [PATCH 3/5] perf: MMO feedforward sponge for 124-bit collision security Replace standard outer-sponge with Matyas-Meyer-Oseas (MMO) feedforward construction. Same Poseidon-16 permutation, same RATE=12, but collision security lifts from 62-bit to 124-bit by chaining the full 16-element state instead of just the 4-element capacity. Changes: - sponge.rs: mmo_hash_slice, mmo_hash_rtl_iter, mmo_precompute_zero_suffix_state with full-state feedforward (field-add the pre-perm state into the post-perm state) - merkle.rs: wire MMO hash functions into Merkle tree construction - verifier.rs: use MMO hash in verification path - poseidon_16: new poseidon16_permute precompile (16-element output) for zk-DSL recursive verifier, with AIR constraints and trace generation - hashing.py: zk-DSL updated to use MMO via poseidon16_permute precompile Security: standard sponge collision = c*log2(p)/2 = 62 bits (unshippable).
MMO collision = b-bit birthday on full state output = 124 bits (meets target). Verified against: Coratger-Khovratovich-Wagner-Mennink 2026, SAFE proof (eprint 2023/520), Beetle (CHES 2018). --- crates/backend/fiat-shamir/src/verifier.rs | 2 +- crates/backend/symetric/src/merkle.rs | 4 +- crates/backend/symetric/src/sponge.rs | 186 ++++++++++++++++++ crates/lean_compiler/snark_lib.py | 14 ++ .../lean_compiler/src/a_simplify_lang/mod.rs | 7 +- .../lean_compiler/src/instruction_encoder.rs | 2 + crates/lean_prover/src/trace_gen.rs | 19 ++ crates/lean_vm/src/isa/instruction.rs | 19 +- crates/lean_vm/src/tables/poseidon_16/mod.rs | 117 +++++++++-- .../src/tables/poseidon_16/trace_gen.rs | 20 +- .../rec_aggregation/zkdsl_implem/hashing.py | 48 ++--- crates/utils/src/poseidon.rs | 8 + crates/whir/src/merkle.rs | 15 +- 13 files changed, 402 insertions(+), 59 deletions(-) diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index 4459a8d60..c9053d905 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -84,7 +84,7 @@ where } let mut buf: Vec = data.to_vec(); buf.resize(padded_len, KoalaBear::default()); - symetric::hash_slice::<_, _, 16, 12, DIGEST_LEN_FE>(&perm, &buf) + symetric::mmo_hash_slice::<_, _, 16, 12, DIGEST_LEN_FE>(&perm, &buf) }; let combine_fn = |left: &[KoalaBear; DIGEST_LEN_FE], right: &[KoalaBear; DIGEST_LEN_FE]| { symetric::compress(&perm, [*left, *right]) diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index 676e83f3e..f0834ae8d 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -98,14 +98,14 @@ pub fn merkle_verify bool where - F: Default + Copy + PartialEq, + F: field::PrimeCharacteristicRing + PartialEq, Comp: Compression<[F; WIDTH]>, { if opening_proof.len() != log_height { return false; } - let mut root = crate::hash_slice::<_, _, WIDTH, RATE, DIGEST_ELEMS>(comp, 
opened_values); + let mut root = crate::mmo_hash_slice::<_, _, WIDTH, RATE, DIGEST_ELEMS>(comp, opened_values); for &sibling in opening_proof.iter() { let (left, right) = if index & 1 == 0 { diff --git a/crates/backend/symetric/src/sponge.rs b/crates/backend/symetric/src/sponge.rs index a65b52e25..e2cf555ba 100644 --- a/crates/backend/symetric/src/sponge.rs +++ b/crates/backend/symetric/src/sponge.rs @@ -1,5 +1,7 @@ // Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). +use field::PrimeCharacteristicRing; + use crate::Compression; // IV should have been added to data when necessary (typically: when the length of the data beeing hashed is not constant). @@ -109,6 +111,150 @@ where state[..OUT].try_into().unwrap() } +// ============================================================================= +// MMO-mode (Davies-Meyer / Matyas-Meyer-Oseas) feedforward sponge +// ============================================================================= +// +// Standard PaddingFreeSponge ("oSponge") collision security is c·log2(p)/2 bits +// because of the inner-state birthday attack on the capacity portion. With +// (WIDTH=16, RATE=12, capacity=4) over KoalaBear (p ~= 2^31), that bound is +// 4*31/2 = 62 bits — short of the 124-bit target. +// +// This MMO variant treats each absorb step as the Matyas-Meyer-Oseas +// compression F(state, M) = s + perm(s), s = state + (0_cap, M), i.e. message is +// ADDED into the rate positions (not overwritten) and the full pre-perm state s +// is fed forward. The chaining variable is then the FULL 16-element state +// (496 bits), not the 4-element capacity, so generic compression collision is +// 2^{b/2} = 2^248 in the random-permutation model, and after truncation to +// OUT=8 elements the digest birthday gives 2^{output_bits/2} = 2^124.
+// +// IMPORTANT: `Compression::compress_mut` for Poseidon-16 in this codebase ALREADY +// computes `output = perm(input) + input` (matching the AIR's `eval_last_2_full_rounds_16` +// which adds initial state to post-perm state — see lean_vm/src/tables/poseidon_16/mod.rs). +// So a single `compress_mut` call IS one MMO step; we must NOT add `prev` again +// after, or we'd double-feedforward and disagree with the zk-DSL precompile. +// +// Convention matches the existing hash_slice: the first 16 elements of data +// are loaded directly into the state and the precompile is invoked once +// (zero IV implicit). Subsequent RATE-sized blocks are absorbed with ADD into +// rate positions, then a single compression invocation gives the next state. + +/// MMO-mode (feedforward) variant of `hash_slice`. Same input format and +/// alignment requirements; collision security is bounded by the digest size +/// rather than the capacity. +pub fn mmo_hash_slice(comp: &Comp, data: &[T]) -> [T; OUT] +where + T: PrimeCharacteristicRing, + Comp: Compression<[T; WIDTH]>, +{ + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); + debug_assert!(data.len() >= WIDTH); + debug_assert!((data.len() - WIDTH).is_multiple_of(RATE)); + let mut state: [T; WIDTH] = data[data.len() - WIDTH..].try_into().unwrap(); + // First MMO compression: state ← perm(state) + state (compress_mut already does this). + comp.compress_mut(&mut state); + let n_remaining_chunks = (data.len() - WIDTH) / RATE; + for chunk_idx in (0..n_remaining_chunks).rev() { + let offset = chunk_idx * RATE; + // ADD message into rate positions (not overwrite). + for i in 0..RATE { + state[WIDTH - RATE + i] += data[offset + i]; + } + // One MMO compression: state ← perm(state) + state. compress_mut already + // performs the full-state feedforward. + comp.compress_mut(&mut state); + } + state[..OUT].try_into().unwrap() +} + +/// MMO-mode variant of `precompute_zero_suffix_state`. 
Same number of perm +/// calls as the standard variant (n_zero_chunks - 1 total). +pub fn mmo_precompute_zero_suffix_state( + comp: &Comp, + n_zero_chunks: usize, +) -> [T; WIDTH] +where + T: PrimeCharacteristicRing, + Comp: Compression<[T; WIDTH]>, +{ + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); + debug_assert!(n_zero_chunks >= 2); + let mut state = [T::ZERO; WIDTH]; + // First absorb (16 zeros). compress_mut applies one MMO compression. + comp.compress_mut(&mut state); + // Subsequent (n_zero_chunks - 2) absorbs of zero RATE-chunks. ADD 0 is a + // no-op, so each iteration is just one MMO compression. + for _ in 0..n_zero_chunks - 2 { + comp.compress_mut(&mut state); + } + state +} + +/// RTL = Right-to-left. MMO-mode counterpart of `hash_rtl_iter`. +#[inline(always)] +pub fn mmo_hash_rtl_iter( + comp: &Comp, + rtl_iter: I, +) -> [T; OUT] +where + T: PrimeCharacteristicRing, + Comp: Compression<[T; WIDTH]>, + I: IntoIterator, +{ + debug_assert!(OUT <= WIDTH); + debug_assert!(RATE <= WIDTH); + let mut state = [T::ZERO; WIDTH]; + let mut iter = rtl_iter.into_iter(); + for pos in (0..WIDTH).rev() { + state[pos] = iter.next().unwrap(); + } + comp.compress_mut(&mut state); + mmo_absorb_rtl_chunks::(comp, &mut state, &mut iter) +} + +/// RTL = Right-to-left. MMO-mode counterpart of `hash_rtl_iter_with_initial_state`. +#[inline(always)] +pub fn mmo_hash_rtl_iter_with_initial_state( + comp: &Comp, + mut iter: I, + initial_state: &[T; WIDTH], +) -> [T; OUT] +where + T: PrimeCharacteristicRing, + Comp: Compression<[T; WIDTH]>, + I: Iterator, +{ + let mut state = *initial_state; + mmo_absorb_rtl_chunks::(comp, &mut state, &mut iter) +} + +/// RTL = Right-to-left. MMO-mode chunk absorption: ADD message into rate, then +/// one MMO compression (compress_mut already does perm + input feedforward). 
+#[inline(always)] +fn mmo_absorb_rtl_chunks( + comp: &Comp, + state: &mut [T; WIDTH], + iter: &mut I, +) -> [T; OUT] +where + T: PrimeCharacteristicRing, + Comp: Compression<[T; WIDTH]>, + I: Iterator, +{ + while let Some(elem) = iter.next() { + // ADD into rate positions (last RATE elements), reading the iterator + // from right to left. + state[WIDTH - 1] += elem; + for pos in (WIDTH - RATE..WIDTH - 1).rev() { + state[pos] += iter.next().unwrap(); + } + comp.compress_mut(state); + } + state[..OUT].try_into().unwrap() +} + #[cfg(test)] mod tests { use super::*; @@ -135,4 +281,44 @@ mod tests { let h_rtl = hash_rtl_iter::(&perm, data.iter().rev().copied()); assert_eq!(h_slice, h_rtl, "hash_slice and hash_rtl_iter must agree on equivalent inputs (RATE=8)"); } + + /// MMO-mode counterpart of hash_slice_matches_rtl_iter_rate12. + #[test] + fn mmo_hash_slice_matches_rtl_iter_rate12() { + let perm = default_koalabear_poseidon1_16(); + let data: Vec = (0..100u32).map(|i| KoalaBear::from_u32(i)).collect(); + let h_slice = mmo_hash_slice::(&perm, &data); + let h_rtl = mmo_hash_rtl_iter::(&perm, data.iter().rev().copied()); + assert_eq!(h_slice, h_rtl, "mmo_hash_slice and mmo_hash_rtl_iter must agree on equivalent inputs"); + } + + /// MMO-mode is structurally distinct from oSponge — verify they produce + /// different digests on the same input (sanity check that we are not + /// accidentally falling back to the standard sponge). + #[test] + fn mmo_differs_from_standard_sponge() { + let perm = default_koalabear_poseidon1_16(); + let data: Vec = (0..28u32).map(|i| KoalaBear::from_u32(i)).collect(); // 16 + 12, two-block input + let h_std = hash_slice::(&perm, &data); + let h_mmo = mmo_hash_slice::(&perm, &data); + assert_ne!(h_std, h_mmo, "MMO must differ from standard sponge for multi-block inputs"); + } + + /// Verify the MMO precompute is consistent with directly hashing zeros. 
+ #[test] + fn mmo_precompute_zero_suffix_matches_full_zero_hash() { + let perm = default_koalabear_poseidon1_16(); + let n_zero_chunks: usize = 4; // WIDTH absorb + 3 RATE absorbs of zero + let zeros: Vec = + std::iter::repeat_n(KoalaBear::ZERO, 16 + 12 * (n_zero_chunks - 1)).collect(); + let direct = mmo_hash_slice::(&perm, &zeros); + let pre = mmo_precompute_zero_suffix_state::(&perm, n_zero_chunks); + // The precompute does (n_zero_chunks - 1) MMO compressions; mmo_hash_slice + // does n_zero_chunks total. To finalize we need ONE MORE compression + // (ADDing zero rate is a no-op). + let mut state = pre; + perm.compress_mut(&mut state); + let advanced: [KoalaBear; 8] = state[..8].try_into().unwrap(); + assert_eq!(advanced, direct); + } } diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index 5d13b761f..e15e474f4 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -89,6 +89,20 @@ def poseidon16_compress_half_hardcoded_left(left, right, output, offset): _ = left, right, output, offset +def poseidon16_permute(left, right, output): + """Apply Poseidon-16 with input feedforward (MMO compression) and write all + 16 output elements to memory[output..output+16]. + + output[0..8] = perm(left || right)[0..8] + left + output[8..16] = perm(left || right)[8..16] + right + + Used for MMO sponge leaf hashing — the FULL 16-element state must be + chained between rounds to achieve `output_bits/2 = 124`-bit collision + security. Allocate Array(16) (NOT Array(8)) for the result. 
+ """ + _ = left, right, output + + def add_be(a, b, result, length=None): _ = a, b, result, length diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index b3c121d88..451d5ab5f 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -7,7 +7,8 @@ use crate::{ use backend::PrimeCharacteristicRing; use lean_vm::{ ALL_POSEIDON16_NAMES, Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName, - POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, PrecompileArgs, + POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, + POSEIDON16_PERMUTE_NAME, PrecompileArgs, PrecompileCompTimeArgs, SourceLocation, }; use std::{ @@ -2259,7 +2260,7 @@ fn simplify_lines( continue; } - // Special handling for poseidon16 precompile (4 variants) + // Special handling for poseidon16 precompile (5 variants) if ALL_POSEIDON16_NAMES.contains(&function_name.as_str()) { if !targets.is_empty() { return Err(format!( @@ -2268,6 +2269,7 @@ fn simplify_lines( } let half_output = [POSEIDON16_HALF_NAME, POSEIDON16_HALF_HARDCODED_LEFT_NAME] .contains(&function_name.as_str()); + let full_output = function_name.as_str() == POSEIDON16_PERMUTE_NAME; let is_hardcoded_left = [POSEIDON16_HARDCODED_LEFT_NAME, POSEIDON16_HALF_HARDCODED_LEFT_NAME] .contains(&function_name.as_str()); @@ -2302,6 +2304,7 @@ fn simplify_lines( res: simplified_args[2].clone(), data: PrecompileCompTimeArgs::Poseidon16 { half_output, + full_output, hardcoded_offset_left, }, })); diff --git a/crates/lean_compiler/src/instruction_encoder.rs b/crates/lean_compiler/src/instruction_encoder.rs index 1060e3be4..42b284624 100644 --- a/crates/lean_compiler/src/instruction_encoder.rs +++ b/crates/lean_compiler/src/instruction_encoder.rs @@ -50,6 +50,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { let 
precompile_data = match &precompile.data { PrecompileCompTimeArgs::Poseidon16 { half_output, + full_output, hardcoded_offset_left, } => { let flag_left = hardcoded_offset_left.is_some() as usize; @@ -58,6 +59,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { + POSEIDON_HALF_OUTPUT_SHIFT * (*half_output as usize) + POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT * flag_left + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val + + POSEIDON_FULL_OUTPUT_SHIFT * (*full_output as usize) } PrecompileCompTimeArgs::ExtensionOp { size, mode } => { assert!(*size >= 1, "invalid extension_op size={size}"); diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index 1801a5b62..fe3c4eed7 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -130,6 +130,25 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul }); } + // For non-full-output rows, zero outputs_high (AIR constrains them to zero) and point + // index_input_res_high at the zero-vec region so the high-half memory lookup is a no-op + // (m[zero_vec_ptr + i] = 0 = outputs_high[i]). + { + // Snapshot flag column (immutable copy) before taking mutable references to the trace. 
+ let full_output_flags: Vec = + poseidon_trace.columns[POSEIDON_16_COL_FLAG_FULL_OUTPUT].clone(); + let zero_ptr = F::from_usize(padding_zero_vec_ptr); + let n_rows = full_output_flags.len(); + for row_idx in 0..n_rows { + if full_output_flags[row_idx] != F::ONE { + poseidon_trace.columns[POSEIDON_16_COL_INDEX_INPUT_RES_HIGH][row_idx] = zero_ptr; + for j in 0..DIGEST_LEN { + poseidon_trace.columns[POSEIDON_16_COL_OUTPUTS_HIGH_START + j][row_idx] = F::ZERO; + } + } + } + } + let extension_op_trace = traces.get_mut(&Table::extension_op()).unwrap(); fill_trace_extension_op(extension_op_trace, &memory_padded); diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index f0b7ef212..c3e42fd4d 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -65,6 +65,9 @@ pub struct PrecompileArgs { pub enum PrecompileCompTimeArgs { Poseidon16 { half_output: bool, + /// Permute mode: write all 16 elements of perm(input)+input to memory at `res`. + /// Mutually exclusive with `half_output`. Used by the MMO sponge leaf hash. 
+ full_output: bool, // hardcoded_offset_left = None: left_input = m[arg_a..arg_a+8] // hardcoded_offset_left = Some(offset_left): left_input = m[offset_left..offset_left+4] | m[arg_a..arg_a+4] (arg_a is the first runtime parameter) hardcoded_offset_left: Option, @@ -87,9 +90,11 @@ impl PrecompileCompTimeArgs { match self { Self::Poseidon16 { half_output, + full_output, hardcoded_offset_left: hardcoded_left_4, } => PrecompileCompTimeArgs::Poseidon16 { half_output, + full_output, hardcoded_offset_left: hardcoded_left_4.map(&mut f), }, Self::ExtensionOp { size, mode } => PrecompileCompTimeArgs::ExtensionOp { size: f(size), mode }, @@ -252,12 +257,16 @@ impl Display for PrecompileArgs { match data { PrecompileCompTimeArgs::Poseidon16 { half_output, + full_output, hardcoded_offset_left: hardcoded_left_4, - } => match (*half_output, hardcoded_left_4) { - (false, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res})"), - (true, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half)"), - (false, Some(off)) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, hardcoded_left_4={off})"), - (true, Some(off)) => write!( + } => match (*full_output, *half_output, hardcoded_left_4) { + (true, _, _) => write!(f, "poseidon16_permute({arg_0}, {arg_1}, {res})"), + (false, false, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res})"), + (false, true, None) => write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half)"), + (false, false, Some(off)) => { + write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, hardcoded_left_4={off})") + } + (false, true, Some(off)) => write!( f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res}, half, hardcoded_left_4={off})" ), diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index cd047aa21..35e2b83b1 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -98,6 +98,9 @@ pub const POSEIDON_PRECOMPILE_DATA: 
usize = 1; pub const POSEIDON_HALF_OUTPUT_SHIFT: usize = 1 << 1; pub const POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT: usize = 1 << 2; pub const POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT: usize = 1 << 3; +// Bit 30 is safely beyond `8 * MAX_LOG_MEMORY_SIZE = 2^29` so it cannot +// alias with `POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * offset`. +pub const POSEIDON_FULL_OUTPUT_SHIFT: usize = 1 << 30; pub const POSEIDON_16_COL_FLAG: ColIndex = 0; pub const POSEIDON_16_COL_INDEX_INPUT_RIGHT: ColIndex = 1; @@ -108,7 +111,14 @@ pub const POSEIDON_16_COL_OFFSET_LEFT_HARDCODED: ColIndex = 5; pub const POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_FIRST: ColIndex = 6; pub const POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_SECOND: ColIndex = 7; pub const POSEIDON_16_COL_INPUT_START: ColIndex = 8; -pub const POSEIDON_16_COL_OUTPUT_START: ColIndex = num_cols_poseidon_16() - 8; +// Layout at end of struct (in field-declaration order): +// ... outputs (DIGEST_LEN cols) ... flag_full_output (1) ... index_input_res_high (1) ... outputs_high (DIGEST_LEN) +// So OUTPUTS_HIGH_START = num_cols - DIGEST_LEN, INDEX_INPUT_RES_HIGH = num_cols - DIGEST_LEN - 1, +// FLAG_FULL_OUTPUT = num_cols - DIGEST_LEN - 2, OUTPUT_START = num_cols - DIGEST_LEN - 2 - DIGEST_LEN. 
+pub const POSEIDON_16_COL_OUTPUTS_HIGH_START: ColIndex = num_cols_poseidon_16() - DIGEST_LEN; +pub const POSEIDON_16_COL_INDEX_INPUT_RES_HIGH: ColIndex = POSEIDON_16_COL_OUTPUTS_HIGH_START - 1; +pub const POSEIDON_16_COL_FLAG_FULL_OUTPUT: ColIndex = POSEIDON_16_COL_INDEX_INPUT_RES_HIGH - 1; +pub const POSEIDON_16_COL_OUTPUT_START: ColIndex = POSEIDON_16_COL_FLAG_FULL_OUTPUT - DIGEST_LEN; /// Non-committed columns ("virtual"): pub const POSEIDON_16_COL_INDEX_INPUT_LEFT: ColIndex = num_cols_poseidon_16(); pub const POSEIDON_16_COL_PRECOMPILE_DATA: ColIndex = num_cols_poseidon_16() + 1; @@ -117,11 +127,16 @@ pub const POSEIDON16_NAME: &str = "poseidon16_compress"; pub const POSEIDON16_HALF_NAME: &str = "poseidon16_compress_half"; pub const POSEIDON16_HARDCODED_LEFT_NAME: &str = "poseidon16_compress_hardcoded_left"; pub const POSEIDON16_HALF_HARDCODED_LEFT_NAME: &str = "poseidon16_compress_half_hardcoded_left"; -pub const ALL_POSEIDON16_NAMES: [&str; 4] = [ +/// Permute mode: writes ALL 16 perm-output elements (with input feedforward) to memory. +/// Used for MMO sponge leaf hashing where the FULL 16-element state must be chained +/// between rounds to achieve `output_bits/2 = 124`-bit collision security at any rate. +pub const POSEIDON16_PERMUTE_NAME: &str = "poseidon16_permute"; +pub const ALL_POSEIDON16_NAMES: [&str; 5] = [ POSEIDON16_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, POSEIDON16_HALF_HARDCODED_LEFT_NAME, + POSEIDON16_PERMUTE_NAME, ]; pub const HALF_DIGEST_LEN: usize = DIGEST_LEN / 2; @@ -157,6 +172,13 @@ impl TableT for Poseidon16Precompile { index: POSEIDON_16_COL_INDEX_INPUT_RES, values: (POSEIDON_16_COL_OUTPUT_START..POSEIDON_16_COL_OUTPUT_START + DIGEST_LEN).collect(), }, + // High-half output lookup (only meaningful in permute mode, but always active). + // For non-permute rows the trace_gen sets index_input_res_high = zero_vec_ptr and + // outputs_high = 0, so this lookup checks `m[zero_vec_ptr+i] == 0` (trivially true). 
+ LookupIntoMemory { + index: POSEIDON_16_COL_INDEX_INPUT_RES_HIGH, + values: (POSEIDON_16_COL_OUTPUTS_HIGH_START..POSEIDON_16_COL_OUTPUTS_HIGH_START + DIGEST_LEN).collect(), + }, ] } @@ -194,11 +216,22 @@ impl TableT for Poseidon16Precompile { *perm.offset_hardcoded_left = F::ZERO; *perm.effective_index_left_first = F::from_usize(zero_vec_ptr); *perm.effective_index_left_second = F::from_usize(zero_vec_ptr + HALF_DIGEST_LEN); + *perm.flag_full_output = F::ZERO; + // Padding rows are non-permute → high-output index points at zero_vec_ptr (a 16-cell zero region). + *perm.index_input_res_high = F::from_usize(zero_vec_ptr); + // outputs_high is zeroed by the constraint `(1 - flag_full_output) * outputs_high[i] = 0`; + // the trace generator below leaves them at F::ZERO via the Vec::new() default. // Non-committed columns row[POSEIDON_16_COL_INDEX_INPUT_LEFT] = F::from_usize(zero_vec_ptr); row[POSEIDON_16_COL_PRECOMPILE_DATA] = F::from_usize(POSEIDON_PRECOMPILE_DATA); generate_trace_rows_for_perm(perm); + // generate_trace_rows_for_perm fills outputs[0..8] with state[i] + inputs[i]; for padding + // rows inputs are all zero so outputs ≡ Poseidon-16(0) (8 elements). outputs_high however + // must be zero per the AIR constraint, so explicitly clear it after the perm trace fill. 
+ for output_high in perm.outputs_high.iter_mut() { + **output_high = F::ZERO; + } row } @@ -213,11 +246,13 @@ impl TableT for Poseidon16Precompile { ) -> Result<(), RunnerError> { let PrecompileCompTimeArgs::Poseidon16 { half_output, + full_output, hardcoded_offset_left, } = args else { unreachable!("Poseidon16 table called with non-Poseidon16 args"); }; + debug_assert!(!(half_output && full_output), "half_output and full_output are mutually exclusive"); let trace = ctx.traces.get_mut(&self.table()).unwrap(); let arg_a_usize = arg_a.to_usize(); @@ -242,13 +277,17 @@ impl TableT for Poseidon16Precompile { input[HALF_DIGEST_LEN..DIGEST_LEN].copy_from_slice(&arg0_second); input[DIGEST_LEN..].copy_from_slice(&arg1); - let output = poseidon16_compress(input); - - if half_output { - ctx.memory - .set_slice(index_res_a.to_usize(), &output[..HALF_DIGEST_LEN])?; + let res_addr = index_res_a.to_usize(); + if full_output { + // Write all 16 elements (perm output + input feedforward) to memory. + let full = utils::poseidon16_permute_full(input); + ctx.memory.set_slice(res_addr, &full)?; + } else if half_output { + let output = poseidon16_compress(input); + ctx.memory.set_slice(res_addr, &output[..HALF_DIGEST_LEN])?; } else { - ctx.memory.set_slice(index_res_a.to_usize(), &output)?; + let output = poseidon16_compress(input); + ctx.memory.set_slice(res_addr, &output)?; } let hardcoded_offset_left_val = hardcoded_offset_left.unwrap_or(0); @@ -264,12 +303,23 @@ impl TableT for Poseidon16Precompile { for (i, value) in input.iter().enumerate() { trace.columns[POSEIDON_16_COL_INPUT_START + i].push(*value); } + trace.columns[POSEIDON_16_COL_FLAG_FULL_OUTPUT].push(if full_output { F::ONE } else { F::ZERO }); + // index_input_res_high: real address (res+8) when in permute mode; otherwise a placeholder + // that will be rewritten in lean_prover post-processing to a zero region. 
Use 0 for now; + // soundness is maintained because the AIR constraint + // `flag_full_output * (index_input_res_high - index_input_res - 8) = 0` + // only forces the value when permute mode is on. + let index_high = if full_output { res_addr + DIGEST_LEN } else { 0 }; + trace.columns[POSEIDON_16_COL_INDEX_INPUT_RES_HIGH].push(F::from_usize(index_high)); + // outputs_high columns are filled by trace_gen (perm output high half) for permute rows, + // and overwritten to zero in lean_prover post-processing for non-permute rows. // Non-committed columns trace.columns[POSEIDON_16_COL_INDEX_INPUT_LEFT].push(arg_a); let precompile_data = POSEIDON_PRECOMPILE_DATA + POSEIDON_HALF_OUTPUT_SHIFT * (half_output as usize) + POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT * (flag_hardcoded as usize) - + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val; + + POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT * hardcoded_offset_left_val + + POSEIDON_FULL_OUTPUT_SHIFT * (full_output as usize); trace.columns[POSEIDON_16_COL_PRECOMPILE_DATA].push(F::from_usize(precompile_data)); // the rest of the trace is filled at the end of the execution (to get parallelism + SIMD) @@ -294,7 +344,9 @@ impl Air for Poseidon16Precompile { vec![] } fn n_constraints(&self) -> usize { - BUS as usize + 80 + // 80 (existing) + 1 (full_output bool) + 1 (full*half mutex) + 1 (high index) + // + 8 (full * (outputs_high[i] - state - input)) + 8 ((1-full) * outputs_high[i]) + BUS as usize + 80 + 19 } fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { let cols: Poseidon1Cols16 = { @@ -311,7 +363,8 @@ impl Air for Poseidon16Precompile { + cols.flag_hardcoded_left * AB::F::from_usize(POSEIDON_HARDCODED_LEFT_4_FLAG_SHIFT) + cols.flag_hardcoded_left * cols.offset_hardcoded_left - * AB::F::from_usize(POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT); + * AB::F::from_usize(POSEIDON_HARDCODED_LEFT_4_OFFSET_SHIFT) + + cols.flag_full_output * AB::F::from_usize(POSEIDON_FULL_OUTPUT_SHIFT); // effective_index_left_first = 
index_a * (1 - flag_hardcoded_left_4) + offset * flag_hardcoded_left_4 let one_minus_flag_hardcoded_left = AB::IF::ONE - cols.flag_hardcoded_left; @@ -333,6 +386,17 @@ impl Air for Poseidon16Precompile { builder.assert_bool(cols.flag_active); builder.assert_bool(cols.flag_half_output); builder.assert_bool(cols.flag_hardcoded_left); + builder.assert_bool(cols.flag_full_output); + // Mutually exclusive: a row cannot be both half-output and full-output. + builder.assert_zero(cols.flag_full_output * cols.flag_half_output); + // When full_output is set, index_input_res_high MUST equal index_res + DIGEST_LEN so that + // outputs_high lands at m[res+8..res+16]. When full_output is unset, the trace generator + // is free to set index_input_res_high to any zero-page address; the lookup will check + // outputs_high (= 0) against m[that_address+i] which is zero by construction. + builder.assert_zero( + cols.flag_full_output + * (cols.index_input_res_high - cols.index_res - AB::F::from_usize(DIGEST_LEN)), + ); builder.assert_zero(cols.flag_hardcoded_left * (cols.offset_hardcoded_left - cols.effective_index_left_first)); builder.assert_zero(one_minus_flag_hardcoded_left * (index_a - cols.effective_index_left_first)); @@ -358,6 +422,15 @@ pub(super) struct Poseidon1Cols16 { pub partial_rounds: [T; PARTIAL_ROUNDS], pub ending_full_rounds: [[T; WIDTH]; HALF_FINAL_FULL_ROUNDS - 1], pub outputs: [T; WIDTH / 2], + /// 1 = expose all 16 perm-output elements (writes outputs_high to m[res+8..res+16]). + /// Mutually exclusive with flag_half_output. + pub flag_full_output: T, + /// Memory address for the high-half outputs. = index_res + DIGEST_LEN when flag_full_output; + /// otherwise points at zero_vec_ptr (a region pre-filled with zeros) so the lookup is a no-op. + pub index_input_res_high: T, + /// High-half perm output (state[8..16] + inputs[8..16]). Constrained when flag_full_output; + /// forced to zero when not, so the lookup against zero_vec_ptr is trivially satisfied. 
+ pub outputs_high: [T; WIDTH / 2], } fn eval_poseidon1_16(builder: &mut AB, local: &Poseidon1Cols16) { @@ -417,9 +490,11 @@ fn eval_poseidon1_16(builder: &mut AB, local: &Poseidon1Cols16( initial_state: &[AB::IF; WIDTH], state: &mut [AB::IF; WIDTH], outputs: &[AB::IF; WIDTH / 2], + outputs_high: &[AB::IF; WIDTH / 2], round_constants_1: &[F; WIDTH], round_constants_2: &[F; WIDTH], flag_half_output: AB::IF, + flag_full_output: AB::IF, builder: &mut AB, ) { for (s, r) in state.iter_mut().zip(round_constants_1.iter()) { @@ -477,20 +554,28 @@ fn eval_last_2_full_rounds_16( *s = s.cube(); } mds_air_16(state); - // add inputs to outputs (for compression) + // add inputs to outputs (for compression / MMO feedforward) for (state_i, init_state_i) in state.iter_mut().zip(initial_state) { *state_i += *init_state_i; } let one_minus_flag_half_output = AB::IF::ONE - flag_half_output; - for (idx, (state_i, output_i)) in state.iter_mut().zip(outputs).enumerate() { + let one_minus_flag_full_output = AB::IF::ONE - flag_full_output; + // First 8 outputs: existing behavior (always 0..4, gated by half on 4..8). + for (idx, (state_i, output_i)) in state.iter().take(WIDTH / 2).zip(outputs).enumerate() { if idx < HALF_DIGEST_LEN { - // First 4 outputs: always constrained builder.assert_eq(*state_i, *output_i); } else { - // Last 4 outputs: constrained only when half_output = 0 builder.assert_zero(one_minus_flag_half_output * (*state_i - *output_i)); } - *state_i = *output_i; + } + // Outputs_high: constrained to state[8..16] when full_output, else forced to zero. + for (state_i, output_high_i) in state.iter().skip(WIDTH / 2).zip(outputs_high) { + builder.assert_zero(flag_full_output * (*state_i - *output_high_i)); + builder.assert_zero(one_minus_flag_full_output * *output_high_i); + } + // Mirror the original "advance state to output" so any downstream code sees the canonical state. 
+ for (idx, output_i) in outputs.iter().enumerate() { + state[idx] = *output_i; } } diff --git a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs index fca712257..4180c2170 100644 --- a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs +++ b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs @@ -104,6 +104,7 @@ pub(super) fn generate_trace_rows_for_perm + Copy>(perm: & &mut state, &inputs, &mut perm.outputs, + &mut perm.outputs_high, &poseidon1_final_constants()[2 * n_ending_full_rounds], &poseidon1_final_constants()[2 * n_ending_full_rounds + 1], ); @@ -138,6 +139,7 @@ fn generate_last_2_full_rounds + Copy>( state: &mut [F; WIDTH], inputs: &[F; WIDTH], outputs: &mut [&mut F; WIDTH / 2], + outputs_high: &mut [&mut F; WIDTH / 2], round_constants_1: &[KoalaBear; WIDTH], round_constants_2: &[KoalaBear; WIDTH], ) { @@ -153,8 +155,20 @@ fn generate_last_2_full_rounds + Copy>( } mds_circ_16(state); - // Add inputs to outputs (compression) - for ((output, state_i), &input_i) in outputs.iter_mut().zip(state).zip(inputs) { - **output = *state_i + input_i; + // Add inputs to outputs (compression / MMO feedforward). + // First half of state goes into `outputs`; second half into `outputs_high`. + // Note: the AIR forces outputs_high to zero when flag_full_output = 0; the + // lean_prover post-processing pass overwrites these columns to zero for + // non-full-output rows. For full-output rows the values written here are + // exactly what the AIR + lookup expect (state[i+8] + inputs[i+8]). 
+ for (idx, (output, &input_i)) in outputs.iter_mut().zip(inputs.iter().take(WIDTH / 2)).enumerate() { + **output = state[idx] + input_i; + } + for (idx, (output_high, &input_i)) in outputs_high + .iter_mut() + .zip(inputs.iter().skip(WIDTH / 2)) + .enumerate() + { + **output_high = state[idx + WIDTH / 2] + input_i; } } diff --git a/crates/rec_aggregation/zkdsl_implem/hashing.py b/crates/rec_aggregation/zkdsl_implem/hashing.py index f9a0870d6..0459c9c88 100644 --- a/crates/rec_aggregation/zkdsl_implem/hashing.py +++ b/crates/rec_aggregation/zkdsl_implem/hashing.py @@ -114,39 +114,41 @@ def slice_hash_rtl_rate12(data, data_len: Const, padded_len: Const, n_chunks_12: def slice_hash_rtl_rate12_no_pad(padded_data, padded_len: Const, n_chunks_12: Const): - # states[k*8..(k+1)*8] is the 8-element output of round k's compress. - states = Array((n_chunks_12 + 1) * DIGEST_LEN) + """MMO sponge: ADD message into rate, full-state feedforward. - # Round 0: initial state from last 16 elements of padded_data. - poseidon16_compress( + The chaining variable is the FULL 16-element state across rounds, giving + output_bits/2 = 124-bit collision security regardless of capacity. + The output is the first 8 elements of the final 16-element state. + + states[k*16..(k+1)*16] holds the full 16-element state after round k. + """ + states = Array((n_chunks_12 + 1) * 16) + + # Round 0: states[0..16] = padded_data[len-16..len] + perm(padded_data[len-16..len]) + # (zero IV implicit; first absorb feeds the last 16 elements of input as the initial state). + poseidon16_permute( padded_data + padded_len - 16, padded_data + padded_len - 8, states, ) - # Subsequent rounds: absorb 12-element chunks RTL. + # Subsequent rounds: absorb 12-element chunks RTL using MMO compression. 
+ # pre[0..4] = states[j*16..j*16+4] (capacity unchanged) + # pre[4..16] = states[j*16+4..j*16+16] + chunk (rate gets ADDED with chunk) + # states[(j+1)*16..(j+2)*16] = pre + perm(pre) (full-state feedforward) for j in unroll(0, n_chunks_12): chunk_idx = n_chunks_12 - 1 - j - # Build left input (8 elements): [capacity_4 || chunk[0..4]]. - buf = Array(DIGEST_LEN) - buf[0] = states[j * DIGEST_LEN + 0] - buf[1] = states[j * DIGEST_LEN + 1] - buf[2] = states[j * DIGEST_LEN + 2] - buf[3] = states[j * DIGEST_LEN + 3] - buf[4] = padded_data[chunk_idx * 12 + 0] - buf[5] = padded_data[chunk_idx * 12 + 1] - buf[6] = padded_data[chunk_idx * 12 + 2] - buf[7] = padded_data[chunk_idx * 12 + 3] - - # Right input: chunk[4..12]. Output -> states[(j+1)*8..(j+2)*8]. - poseidon16_compress( - buf, - padded_data + chunk_idx * 12 + 4, - states + (j + 1) * DIGEST_LEN, - ) + pre = Array(16) + for k in unroll(0, 4): + pre[k] = states[j * 16 + k] + for k in unroll(0, 12): + pre[4 + k] = states[j * 16 + 4 + k] + padded_data[chunk_idx * 12 + k] + + poseidon16_permute(pre, pre + 8, states + (j + 1) * 16) - return states + n_chunks_12 * DIGEST_LEN + # Output the first 8 elements of the final state. + return states + n_chunks_12 * 16 @inline diff --git a/crates/utils/src/poseidon.rs b/crates/utils/src/poseidon.rs index 3eb2557ba..b8d60d5e5 100644 --- a/crates/utils/src/poseidon.rs +++ b/crates/utils/src/poseidon.rs @@ -24,6 +24,14 @@ pub fn poseidon16_compress(input: [KoalaBear; 16]) -> [KoalaBear; 8] { get_poseidon16().compress(input)[0..8].try_into().unwrap() } +/// Like `poseidon16_compress` but exposes the FULL 16-element output (with +/// input feedforward = MMO compression). Used by the `poseidon16_permute` +/// precompile to support MMO sponge leaf hashing. 
+#[inline(always)] +pub fn poseidon16_permute_full(input: [KoalaBear; 16]) -> [KoalaBear; 16] { + get_poseidon16().compress(input) +} + pub fn poseidon16_compress_pair(left: &[KoalaBear; 8], right: &[KoalaBear; 8]) -> [KoalaBear; 8] { let mut input = [KoalaBear::default(); 16]; input[..8].copy_from_slice(left); diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index 119e4d614..89ed04120 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -96,7 +96,7 @@ fn build_merkle_tree_koalabear( 0 }; let first_layer = if n_zero_suffix_rate_chunks >= 2 { - let scalar_state = symetric::precompute_zero_suffix_state::( + let scalar_state = symetric::mmo_precompute_zero_suffix_state::( &perm, n_zero_suffix_rate_chunks, ); @@ -212,12 +212,13 @@ impl, const DIGEST_ELEMS: effective_base_width: usize, ) -> Self where - P: PackedValue + Default, + F: field::PrimeCharacteristicRing, + P: PackedValue + Default + field::PrimeCharacteristicRing, Perm: Compression<[F; WIDTH]> + Compression<[P; WIDTH]>, { let n_zero_suffix_rate_chunks = (full_leaf_base_width - effective_base_width) / RATE; let first_layer = if n_zero_suffix_rate_chunks >= 2 { - let scalar_state = symetric::precompute_zero_suffix_state::( + let scalar_state = symetric::mmo_precompute_zero_suffix_state::( perm, n_zero_suffix_rate_chunks, ); @@ -260,7 +261,7 @@ fn first_digest_layer Vec<[P::Value; DIGEST_ELEMS]> where - P: PackedValue + Default, + P: PackedValue + Default + field::PrimeCharacteristicRing, P::Value: Default + Copy, Perm: Compression<[P::Value; WIDTH]> + Compression<[P; WIDTH]>, M: Matrix, @@ -280,7 +281,7 @@ where let first_row = i * width; let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, matrix_width, n_trailing_zeros); let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(perm, rtl_iter); + symetric::mmo_hash_rtl_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(perm, rtl_iter); for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { *dst = src; } @@ -297,7 +298,7 @@ fn first_digest_layer_with_initial_state Vec<[P::Value; DIGEST_ELEMS]> where - P: PackedValue + Default, + P: PackedValue + Default + field::PrimeCharacteristicRing, P::Value: Default + Copy, Perm: Compression<[P::Value; WIDTH]> + Compression<[P; WIDTH]>, M: Matrix, @@ -316,7 +317,7 @@ where let first_row = i * width; let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, effective_base_width, n_pad); let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( + symetric::mmo_hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( perm, rtl_iter, packed_initial_state, From 602859ad7bceff864264530f42c1b3348874996c Mon Sep 17 00:00:00 2001 From: Barnadrot Date: Sat, 9 May 2026 12:30:51 +0200 Subject: [PATCH 4/5] perf(poseidon): add #[inline] to hot cross-crate functions for thin LTO Under the workspace default thin LTO profile, the new RATE=12 + MMO sponge code introduced cross-crate calls that did not get inlined: mmo_hash_slice, mmo_precompute_zero_suffix_state, compress_mut, permute_mut. The hot loop in build_merkle_tree_koalabear ended up making out-of-line calls into mt_symetric and mt_koala_bear on every absorb, spilling the 16-element state to the stack each iteration. Adding #[inline] makes these functions available for cross-CGU inlining under thin LTO, matching the codegen fat LTO already produces. No semantic change. The functions are short hot-path wrappers/loops that the compiler should inline anyway given the chance. --- crates/backend/koala-bear/src/poseidon1_koalabear_16.rs | 1 + crates/backend/symetric/src/permutation.rs | 1 + crates/backend/symetric/src/sponge.rs | 2 ++ 3 files changed, 4 insertions(+) diff --git a/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs b/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs index f80a9a7b2..cd7784dfb 100644 --- a/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs +++ b/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs @@ -1038,6 +1038,7 @@ impl Poseidon1KoalaBear16 { impl + InjectiveMonomial<3> + Send + Sync + 'static> Permutation<[R; 16]> for Poseidon1KoalaBear16 { + #[inline] fn permute_mut(&self, input: &mut [R; 16]) { // On targets with a SIMD fast path, dispatch to it when R is the arch-specific packed type. 
#[cfg(any( diff --git a/crates/backend/symetric/src/permutation.rs b/crates/backend/symetric/src/permutation.rs index c129a1dc4..4068f2af3 100644 --- a/crates/backend/symetric/src/permutation.rs +++ b/crates/backend/symetric/src/permutation.rs @@ -16,6 +16,7 @@ pub trait Compression: Clone + Sync { impl + InjectiveMonomial<3> + Send + Sync + 'static> Compression<[R; 16]> for Poseidon1KoalaBear16 { + #[inline] fn compress_mut(&self, input: &mut [R; 16]) { self.compress_in_place(input); } diff --git a/crates/backend/symetric/src/sponge.rs b/crates/backend/symetric/src/sponge.rs index e2cf555ba..aab967146 100644 --- a/crates/backend/symetric/src/sponge.rs +++ b/crates/backend/symetric/src/sponge.rs @@ -142,6 +142,7 @@ where /// MMO-mode (feedforward) variant of `hash_slice`. Same input format and /// alignment requirements; collision security is bounded by the digest size /// rather than the capacity. +#[inline] pub fn mmo_hash_slice(comp: &Comp, data: &[T]) -> [T; OUT] where T: PrimeCharacteristicRing, @@ -170,6 +171,7 @@ where /// MMO-mode variant of `precompute_zero_suffix_state`. Same number of perm /// calls as the standard variant (n_zero_chunks - 1 total). 
+#[inline] pub fn mmo_precompute_zero_suffix_state( comp: &Comp, n_zero_chunks: usize, From 3441e3a9764875c7be760449bfd26f2891602fb7 Mon Sep 17 00:00:00 2001 From: Barnadrot Date: Sat, 9 May 2026 13:04:46 +0200 Subject: [PATCH 5/5] chore: fix clippy and rustfmt for CI - rustfmt: re-flow long lines introduced by the MMO commit - clippy: replace redundant closures in sponge tests with function refs - clippy: allow too_many_arguments on eval_last_2_full_rounds_16 (AIR helper, 9 args) - clippy: rewrite full_output_flags loop with .iter().enumerate() --- .../koala-bear/src/poseidon1_koalabear_16.rs | 7 +++- crates/backend/symetric/src/sponge.rs | 38 +++++++++++++------ .../lean_compiler/src/a_simplify_lang/mod.rs | 5 +-- crates/lean_prover/src/trace_gen.rs | 8 ++-- crates/lean_vm/src/tables/poseidon_16/mod.rs | 9 +++-- .../src/tables/poseidon_16/trace_gen.rs | 6 +-- crates/whir/src/merkle.rs | 12 +++--- 7 files changed, 51 insertions(+), 34 deletions(-) diff --git a/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs b/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs index cd7784dfb..1072c09d2 100644 --- a/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs +++ b/crates/backend/koala-bear/src/poseidon1_koalabear_16.rs @@ -53,7 +53,12 @@ fn dit>(v: &mut [R; 16], } #[inline(always)] -fn neg_dif>(v: &mut [R; 16], lo: usize, hi: usize, t: KoalaBear) { +fn neg_dif>( + v: &mut [R; 16], + lo: usize, + hi: usize, + t: KoalaBear, +) { let (a, b) = (v[lo], v[hi]); v[lo] = a + b; v[hi] = (b - a) * t; diff --git a/crates/backend/symetric/src/sponge.rs b/crates/backend/symetric/src/sponge.rs index aab967146..e195d2a06 100644 --- a/crates/backend/symetric/src/sponge.rs +++ b/crates/backend/symetric/src/sponge.rs @@ -143,7 +143,10 @@ where /// alignment requirements; collision security is bounded by the digest size /// rather than the capacity. 
#[inline] -pub fn mmo_hash_slice(comp: &Comp, data: &[T]) -> [T; OUT] +pub fn mmo_hash_slice( + comp: &Comp, + data: &[T], +) -> [T; OUT] where T: PrimeCharacteristicRing, Comp: Compression<[T; WIDTH]>, @@ -260,38 +263,47 @@ where #[cfg(test)] mod tests { use super::*; - use koala_bear::{KoalaBear, default_koalabear_poseidon1_16}; use field::PrimeCharacteristicRing; + use koala_bear::{KoalaBear, default_koalabear_poseidon1_16}; /// Verify hash_slice(D) == hash_rtl_iter(D.iter().rev()) for arbitrary D with valid length. #[test] fn hash_slice_matches_rtl_iter_rate12() { let perm = default_koalabear_poseidon1_16(); // 100 = 16 + 12*7, compatible with WIDTH=16, RATE=12 - let data: Vec = (0..100u32).map(|i| KoalaBear::from_u32(i)).collect(); + let data: Vec = (0..100u32).map(KoalaBear::from_u32).collect(); let h_slice = hash_slice::(&perm, &data); let h_rtl = hash_rtl_iter::(&perm, data.iter().rev().copied()); - assert_eq!(h_slice, h_rtl, "hash_slice and hash_rtl_iter must agree on equivalent inputs"); + assert_eq!( + h_slice, h_rtl, + "hash_slice and hash_rtl_iter must agree on equivalent inputs" + ); } /// Same as above but for the existing RATE=8 case. #[test] fn hash_slice_matches_rtl_iter_rate8() { let perm = default_koalabear_poseidon1_16(); - let data: Vec = (0..64u32).map(|i| KoalaBear::from_u32(i)).collect(); + let data: Vec = (0..64u32).map(KoalaBear::from_u32).collect(); let h_slice = hash_slice::(&perm, &data); let h_rtl = hash_rtl_iter::(&perm, data.iter().rev().copied()); - assert_eq!(h_slice, h_rtl, "hash_slice and hash_rtl_iter must agree on equivalent inputs (RATE=8)"); + assert_eq!( + h_slice, h_rtl, + "hash_slice and hash_rtl_iter must agree on equivalent inputs (RATE=8)" + ); } /// MMO-mode counterpart of hash_slice_matches_rtl_iter_rate12. 
#[test] fn mmo_hash_slice_matches_rtl_iter_rate12() { let perm = default_koalabear_poseidon1_16(); - let data: Vec = (0..100u32).map(|i| KoalaBear::from_u32(i)).collect(); + let data: Vec = (0..100u32).map(KoalaBear::from_u32).collect(); let h_slice = mmo_hash_slice::(&perm, &data); let h_rtl = mmo_hash_rtl_iter::(&perm, data.iter().rev().copied()); - assert_eq!(h_slice, h_rtl, "mmo_hash_slice and mmo_hash_rtl_iter must agree on equivalent inputs"); + assert_eq!( + h_slice, h_rtl, + "mmo_hash_slice and mmo_hash_rtl_iter must agree on equivalent inputs" + ); } /// MMO-mode is structurally distinct from oSponge — verify they produce @@ -300,10 +312,13 @@ mod tests { #[test] fn mmo_differs_from_standard_sponge() { let perm = default_koalabear_poseidon1_16(); - let data: Vec = (0..28u32).map(|i| KoalaBear::from_u32(i)).collect(); // 16 + 12, two-block input + let data: Vec = (0..28u32).map(KoalaBear::from_u32).collect(); // 16 + 12, two-block input let h_std = hash_slice::(&perm, &data); let h_mmo = mmo_hash_slice::(&perm, &data); - assert_ne!(h_std, h_mmo, "MMO must differ from standard sponge for multi-block inputs"); + assert_ne!( + h_std, h_mmo, + "MMO must differ from standard sponge for multi-block inputs" + ); } /// Verify the MMO precompute is consistent with directly hashing zeros. 
@@ -311,8 +326,7 @@ mod tests { fn mmo_precompute_zero_suffix_matches_full_zero_hash() { let perm = default_koalabear_poseidon1_16(); let n_zero_chunks: usize = 4; // WIDTH absorb + 3 RATE absorbs of zero - let zeros: Vec = - std::iter::repeat_n(KoalaBear::ZERO, 16 + 12 * (n_zero_chunks - 1)).collect(); + let zeros: Vec = std::iter::repeat_n(KoalaBear::ZERO, 16 + 12 * (n_zero_chunks - 1)).collect(); let direct = mmo_hash_slice::(&perm, &zeros); let pre = mmo_precompute_zero_suffix_state::(&perm, n_zero_chunks); // The precompute does (n_zero_chunks - 1) MMO compressions; mmo_hash_slice diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index 451d5ab5f..cba454bf2 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -7,9 +7,8 @@ use crate::{ use backend::PrimeCharacteristicRing; use lean_vm::{ ALL_POSEIDON16_NAMES, Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName, - POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, - POSEIDON16_PERMUTE_NAME, PrecompileArgs, - PrecompileCompTimeArgs, SourceLocation, + POSEIDON16_HALF_HARDCODED_LEFT_NAME, POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, POSEIDON16_PERMUTE_NAME, + PrecompileArgs, PrecompileCompTimeArgs, SourceLocation, }; use std::{ collections::{BTreeMap, BTreeSet}, diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index fe3c4eed7..f3b0be56b 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -135,12 +135,10 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul // (m[zero_vec_ptr + i] = 0 = outputs_high[i]). { // Snapshot flag column (immutable copy) before taking mutable references to the trace. 
- let full_output_flags: Vec = - poseidon_trace.columns[POSEIDON_16_COL_FLAG_FULL_OUTPUT].clone(); + let full_output_flags: Vec = poseidon_trace.columns[POSEIDON_16_COL_FLAG_FULL_OUTPUT].clone(); let zero_ptr = F::from_usize(padding_zero_vec_ptr); - let n_rows = full_output_flags.len(); - for row_idx in 0..n_rows { - if full_output_flags[row_idx] != F::ONE { + for (row_idx, flag) in full_output_flags.iter().enumerate() { + if *flag != F::ONE { poseidon_trace.columns[POSEIDON_16_COL_INDEX_INPUT_RES_HIGH][row_idx] = zero_ptr; for j in 0..DIGEST_LEN { poseidon_trace.columns[POSEIDON_16_COL_OUTPUTS_HIGH_START + j][row_idx] = F::ZERO; diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index 35e2b83b1..af1b8b9bc 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -252,7 +252,10 @@ impl TableT for Poseidon16Precompile { else { unreachable!("Poseidon16 table called with non-Poseidon16 args"); }; - debug_assert!(!(half_output && full_output), "half_output and full_output are mutually exclusive"); + debug_assert!( + !(half_output && full_output), + "half_output and full_output are mutually exclusive" + ); let trace = ctx.traces.get_mut(&self.table()).unwrap(); let arg_a_usize = arg_a.to_usize(); @@ -394,8 +397,7 @@ impl Air for Poseidon16Precompile { // is free to set index_input_res_high to any zero-page address; the lookup will check // outputs_high (= 0) against m[that_address+i] which is zero by construction. 
builder.assert_zero( - cols.flag_full_output - * (cols.index_input_res_high - cols.index_res - AB::F::from_usize(DIGEST_LEN)), + cols.flag_full_output * (cols.index_input_res_high - cols.index_res - AB::F::from_usize(DIGEST_LEN)), ); builder.assert_zero(cols.flag_hardcoded_left * (cols.offset_hardcoded_left - cols.effective_index_left_first)); @@ -533,6 +535,7 @@ fn eval_2_full_rounds_16( } #[inline] +#[allow(clippy::too_many_arguments)] fn eval_last_2_full_rounds_16( initial_state: &[AB::IF; WIDTH], state: &mut [AB::IF; WIDTH], diff --git a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs index 4180c2170..c1350e02c 100644 --- a/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs +++ b/crates/lean_vm/src/tables/poseidon_16/trace_gen.rs @@ -164,11 +164,7 @@ fn generate_last_2_full_rounds + Copy>( for (idx, (output, &input_i)) in outputs.iter_mut().zip(inputs.iter().take(WIDTH / 2)).enumerate() { **output = state[idx] + input_i; } - for (idx, (output_high, &input_i)) in outputs_high - .iter_mut() - .zip(inputs.iter().skip(WIDTH / 2)) - .enumerate() - { + for (idx, (output_high, &input_i)) in outputs_high.iter_mut().zip(inputs.iter().skip(WIDTH / 2)).enumerate() { **output_high = state[idx + WIDTH / 2] + input_i; } } diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index 89ed04120..6b1758a44 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -96,10 +96,11 @@ fn build_merkle_tree_koalabear( 0 }; let first_layer = if n_zero_suffix_rate_chunks >= 2 { - let scalar_state = symetric::mmo_precompute_zero_suffix_state::( - &perm, - n_zero_suffix_rate_chunks, - ); + let scalar_state = + symetric::mmo_precompute_zero_suffix_state::( + &perm, + n_zero_suffix_rate_chunks, + ); let packed_state: [PFPacking; SPONGE_WIDTH] = std::array::from_fn(|i| PFPacking::::from_fn(|_| scalar_state[i])); first_digest_layer_with_initial_state::, _, _, DIGEST_ELEMS, SPONGE_WIDTH, SPONGE_RATE>( @@ -115,7 
+116,8 @@ fn build_merkle_tree_koalabear( padded_full_width, ) }; - let tree = symetric::merkle::MerkleTree::from_first_layer::, _, SPONGE_WIDTH>(&perm, first_layer); + let tree = + symetric::merkle::MerkleTree::from_first_layer::, _, SPONGE_WIDTH>(&perm, first_layer); // Expose UNPADDED width to the protocol; padding is purely a sponge detail. WhirMerkleTree { leaf,