diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index 91cae5a5..a64acc0f 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -91,10 +91,14 @@ pub fn prove_execution( let mut memory_acc = F::zero_vec(memory.len()); info_span!("Building memory access count").in_scope(|| { for (table, trace) in &traces { - for lookup in table.lookups() { - for i in &trace.columns[lookup.index] { - for j in 0..lookup.values.len() { - memory_acc[i.to_usize() + j] += F::ONE; + let buses = table.bus_interactions(); + for group in memory_lookup_groups(&buses) { + let idx_col = &trace.columns[group.idx_col]; + let n = group.value_cols.len(); + for idx in idx_col { + let base = idx.to_usize(); + for ofs in 0..n { + memory_acc[base + ofs] += F::ONE; } } } @@ -122,7 +126,7 @@ pub fn prove_execution( // logup (GKR) let logup_c = prover_state.sample(); prover_state.duplex(); - let logup_alphas = prover_state.sample_vec(log2_ceil_usize(max_bus_width_including_bytecode())); + let logup_alphas = prover_state.sample_vec(LOG_MAX_BUS_WIDTH); let logup_alphas_eq_poly = eval_eq(&logup_alphas); let logup_statements = prove_generic_logup( @@ -181,11 +185,11 @@ pub fn prove_execution( let bus_numerator_value = logup_statements.bus_numerators_values[table]; let bus_denominator_value = logup_statements.bus_denominators_values[table]; let bus_final_value = bus_numerator_value - * match table.bus().direction { + * match table.bus_interactions()[0].direction { BusDirection::Pull => EF::NEG_ONE, BusDirection::Push => EF::ONE, } - + bus_beta * (bus_denominator_value - logup_c); + + bus_beta * (logup_c - bus_denominator_value); let eq_suffix = from_end(gkr_point, *log_n_rows).to_vec(); diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index 8781c6e7..b12c7d0a 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -78,7 +78,7 @@ pub fn verify_execution( let logup_c = verifier_state.sample(); verifier_state.duplex(); - let logup_alphas = verifier_state.sample_vec(log2_ceil_usize(max_bus_width_including_bytecode())); + let logup_alphas = verifier_state.sample_vec(LOG_MAX_BUS_WIDTH); let logup_alphas_eq_poly = eval_eq(&logup_alphas); let logup_statements = verify_generic_logup( @@ -126,11 +126,11 @@ pub fn verify_execution( let bus_numerator_value = logup_statements.bus_numerators_values[table]; let bus_denominator_value = logup_statements.bus_denominators_values[table]; let bus_final_value = bus_numerator_value - * match table.bus().direction { + * match table.bus_interactions()[0].direction { BusDirection::Pull => EF::NEG_ONE, BusDirection::Push => EF::ONE, } - + bus_beta * (bus_denominator_value - logup_c); + + bus_beta * (logup_c - bus_denominator_value); initial_sum += eta_power * bus_final_value; diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index 6c93cee5..50b9371e 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -48,7 +48,7 @@ mod tests { #[test] fn ensure_no_overflow_in_logup() { fn memory_lookups_count(t: &T) -> usize { - t.lookups().iter().map(|l| l.values.len()).sum::() + t.bus_interactions().iter().filter(|bus| bus.is_memory_lookup()).count() } // memory lookup let mut max_memory_logup_sum: u64 = 0; diff --git a/crates/lean_vm/src/tables/execution/mod.rs b/crates/lean_vm/src/tables/execution/mod.rs index c866f38b..852099fb 100644 --- a/crates/lean_vm/src/tables/execution/mod.rs +++ b/crates/lean_vm/src/tables/execution/mod.rs @@ -25,34 +25,33 @@ impl TableT for ExecutionTable { N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS } - fn lookups(&self) -> Vec { - vec![ - LookupIntoMemory { - index: COL_MEM_ADDRESS_A, - values: vec![COL_MEM_VALUE_A], - }, - LookupIntoMemory { - index: COL_MEM_ADDRESS_B, - values: vec![COL_MEM_VALUE_B], - }, - LookupIntoMemory { - index: COL_MEM_ADDRESS_C, - values: vec![COL_MEM_VALUE_C], - }, - ] - } - - fn bus(&self) -> Bus { - Bus { + fn bus_interactions(&self) -> Vec { + let bytecode_lookup = BusInteraction { + direction: BusDirection::Push, + multiplicity: BusMultiplicity::One, + domainsep: BusData::Constant(LOGUP_BYTECODE_DOMAINSEP), + data: (0..N_INSTRUCTION_COLUMNS) + .map(|i| BusData::Column(N_RUNTIME_COLUMNS + i)) + .chain(std::iter::once(BusData::Column(COL_PC))) + .collect(), + }; + let precompile_bus = BusInteraction { direction: BusDirection::Push, - multiplicity: COL_IS_PRECOMPILE, + multiplicity: BusMultiplicity::Column(COL_IS_PRECOMPILE), domainsep: BusData::Column(COL_PRECOMPILE_DOMAINSEP), data: vec![ BusData::Column(COL_EXEC_NU_A), BusData::Column(COL_EXEC_NU_B), BusData::Column(COL_EXEC_NU_C), ], - } + }; + // Convention shared with the other tables: the unique Multiplicity::Column bus + // comes first; everything that follows is Multiplicity::One. + let mut buses = vec![precompile_bus, bytecode_lookup]; + buses.extend(memory_lookups_consecutive(COL_MEM_ADDRESS_A, COL_MEM_VALUE_A, 1)); + buses.extend(memory_lookups_consecutive(COL_MEM_ADDRESS_B, COL_MEM_VALUE_B, 1)); + buses.extend(memory_lookups_consecutive(COL_MEM_ADDRESS_C, COL_MEM_VALUE_C, 1)); + buses } fn padding_row(&self, zero_vec_ptr: usize, _null_hash_ptr: usize, ending_pc: usize) -> Vec { diff --git a/crates/lean_vm/src/tables/extension_op/mod.rs b/crates/lean_vm/src/tables/extension_op/mod.rs index df113de9..a8aced24 100644 --- a/crates/lean_vm/src/tables/extension_op/mod.rs +++ b/crates/lean_vm/src/tables/extension_op/mod.rs @@ -87,34 +87,21 @@ impl TableT for ExtensionOpPrecompile { Table::extension_op() } - fn lookups(&self) -> Vec { - vec![ - LookupIntoMemory { - index: COL_IDX_A, - values: (COL_VA..COL_VA + DIMENSION).collect(), - }, - LookupIntoMemory { - index: COL_IDX_B, - values: (COL_VB..COL_VB + DIMENSION).collect(), - }, - LookupIntoMemory { - index: COL_IDX_RES, - values: (COL_VRES..COL_VRES + DIMENSION).collect(), - }, - ] - } - - fn bus(&self) -> Bus { - Bus { + fn bus_interactions(&self) -> Vec { + let mut buses = vec![BusInteraction { direction: BusDirection::Pull, - multiplicity: COL_MULTIPLICITY_EXTENSION_OP, + multiplicity: BusMultiplicity::Column(COL_MULTIPLICITY_EXTENSION_OP), domainsep: BusData::Column(COL_DOMAINSEP_EXTENSION_OP), data: vec![ BusData::Column(COL_IDX_A), BusData::Column(COL_IDX_B), BusData::Column(COL_IDX_RES), ], - } + }]; + buses.extend(memory_lookups_consecutive(COL_IDX_A, COL_VA, DIMENSION)); + buses.extend(memory_lookups_consecutive(COL_IDX_B, COL_VB, DIMENSION)); + buses.extend(memory_lookups_consecutive(COL_IDX_RES, COL_VRES, DIMENSION)); + buses } fn n_columns_total(&self) -> usize { diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index f8eb18fc..54fc7c86 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -138,44 +138,42 @@ impl TableT for Poseidon16Precompile { Table::poseidon16() } - fn lookups(&self) -> Vec { - vec![ - LookupIntoMemory { - index: POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_FIRST, - values: (POSEIDON_16_COL_INPUT_START..POSEIDON_16_COL_INPUT_START + HALF_DIGEST_LEN).collect(), - }, - LookupIntoMemory { - index: POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_SECOND, - values: (POSEIDON_16_COL_INPUT_START + HALF_DIGEST_LEN..POSEIDON_16_COL_INPUT_START + DIGEST_LEN) - .collect(), - }, - LookupIntoMemory { - index: POSEIDON_16_COL_INDEX_INPUT_RIGHT, - values: (POSEIDON_16_COL_INPUT_START + DIGEST_LEN..POSEIDON_16_COL_INPUT_START + DIGEST_LEN * 2) - .collect(), - }, - LookupIntoMemory { - index: POSEIDON_16_COL_INDEX_INPUT_RES, - values: (POSEIDON_16_COL_OUTPUT_LEFT..POSEIDON_16_COL_OUTPUT_LEFT + DIGEST_LEN * 2).collect(), - }, - ] - } - fn n_columns_total(&self) -> usize { num_cols_total_poseidon_16() } - fn bus(&self) -> Bus { - Bus { + fn bus_interactions(&self) -> Vec { + let mut buses = vec![BusInteraction { direction: BusDirection::Pull, - multiplicity: POSEIDON_16_COL_MULTIPLICITY, + multiplicity: BusMultiplicity::Column(POSEIDON_16_COL_MULTIPLICITY), domainsep: BusData::Column(POSEIDON_16_COL_DOMAINSEP), data: vec![ BusData::Column(POSEIDON_16_COL_INDEX_INPUT_LEFT), BusData::Column(POSEIDON_16_COL_INDEX_INPUT_RIGHT), BusData::Column(POSEIDON_16_COL_INDEX_INPUT_RES), ], - } + }]; + buses.extend(memory_lookups_consecutive( + POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_FIRST, + POSEIDON_16_COL_INPUT_START, + HALF_DIGEST_LEN, + )); + buses.extend(memory_lookups_consecutive( + POSEIDON_16_COL_EFFECTIVE_INDEX_LEFT_SECOND, + POSEIDON_16_COL_INPUT_START + HALF_DIGEST_LEN, + HALF_DIGEST_LEN, + )); + buses.extend(memory_lookups_consecutive( + POSEIDON_16_COL_INDEX_INPUT_RIGHT, + POSEIDON_16_COL_INPUT_START + DIGEST_LEN, + DIGEST_LEN, + )); + buses.extend(memory_lookups_consecutive( + POSEIDON_16_COL_INDEX_INPUT_RES, + POSEIDON_16_COL_OUTPUT_LEFT, + DIGEST_LEN * 2, + )); + buses } fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize, _ending_pc: usize) -> Vec { diff --git a/crates/lean_vm/src/tables/table_enum.rs b/crates/lean_vm/src/tables/table_enum.rs index de55ecdc..ebd6c8e8 100644 --- a/crates/lean_vm/src/tables/table_enum.rs +++ b/crates/lean_vm/src/tables/table_enum.rs @@ -5,7 +5,8 @@ use crate::*; pub const N_TABLES: usize = 3; pub const ALL_TABLES: [Table; N_TABLES] = [Table::execution(), Table::extension_op(), Table::poseidon16()]; -pub const MAX_PRECOMPILE_BUS_WIDTH: usize = 4; +pub const MAX_BUS_WIDTH: usize = N_INSTRUCTION_COLUMNS + 2; // + 1 for PC, + 1 for domainsep +pub const LOG_MAX_BUS_WIDTH: usize = log2_ceil_usize(MAX_BUS_WIDTH); #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(usize)] @@ -60,14 +61,11 @@ impl TableT for Table { fn table(&self) -> Table { delegate_to_inner!(self, table) } - fn lookups(&self) -> Vec { - delegate_to_inner!(self, lookups) - } fn is_execution_table(&self) -> bool { delegate_to_inner!(self, is_execution_table) } - fn bus(&self) -> Bus { - delegate_to_inner!(self, bus) + fn bus_interactions(&self) -> Vec { + delegate_to_inner!(self, bus_interactions) } fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize, ending_pc: usize) -> Vec> { delegate_to_inner!(self, padding_row, zero_vec_ptr, null_hash_ptr, ending_pc) @@ -106,12 +104,6 @@ impl Air for Table { } } -pub fn max_bus_width_including_bytecode() -> usize { - MAX_PRECOMPILE_BUS_WIDTH - .max(N_INSTRUCTION_COLUMNS + 2) // 2 = +1 for PC, +1 for domainsep - .next_power_of_two() -} - pub fn max_air_constraints() -> usize { ALL_TABLES.iter().map(|table| table.n_constraints()).max().unwrap() } @@ -128,9 +120,13 @@ mod tests { } #[test] - fn test_max_precompile_bus_width() { - // +1 for the domainsep - let expected_max_bus_width = ALL_TABLES.iter().map(|table| table.bus().data.len() + 1).max().unwrap(); - assert_eq!(MAX_PRECOMPILE_BUS_WIDTH, expected_max_bus_width); + fn test_max_bus_width() { + let expected_max_bus_width = ALL_TABLES + .iter() + .flat_map(|table| table.bus_interactions()) + .map(|bus| bus.data.len() + 1) + .max() + .unwrap(); + assert_eq!(MAX_BUS_WIDTH, expected_max_bus_width); } } diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index 38d8f76f..87c821c9 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -1,5 +1,5 @@ use crate::execution::memory::MemoryAccess; -use crate::{EF, F, InstructionContext, PrecompileCompTimeArgs, RunnerError, Table}; +use crate::{EF, F, InstructionContext, LOGUP_MEMORY_DOMAINSEP, PrecompileCompTimeArgs, RunnerError, Table}; use backend::*; use std::{any::TypeId, cmp::Reverse, collections::BTreeMap, mem::transmute}; @@ -11,13 +11,6 @@ pub type ColIndex = usize; pub type CommittedStatements = BTreeMap, BTreeMap, BTreeMap)>>; -#[derive(Debug)] -pub struct LookupIntoMemory { - pub index: ColIndex, // should be in base field columns - /// For (i, col_index) in values.iter().enumerate(), For j in 0..num_rows, columns_f[col_index][j] = memory[index[j] + i] - pub values: Vec, -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BusDirection { Pull, @@ -36,17 +29,113 @@ impl BusDirection { #[derive(Debug, Clone, Copy)] pub enum BusData { Column(ColIndex), + ColumnPlusConstant(ColIndex, usize), Constant(usize), } +impl BusData { + pub fn column(self) -> Option { + match self { + Self::Column(c) | Self::ColumnPlusConstant(c, _) => Some(c), + Self::Constant(_) => None, + } + } +} + +#[derive(Debug, Clone, Copy)] +pub enum BusMultiplicity { + One, + Column(ColIndex), +} + #[derive(Debug)] -pub struct Bus { +pub struct BusInteraction { pub direction: BusDirection, - pub multiplicity: ColIndex, + pub multiplicity: BusMultiplicity, pub domainsep: BusData, pub data: Vec, } +impl BusInteraction { + pub fn is_memory_lookup(&self) -> bool { + matches!(self.domainsep, BusData::Constant(LOGUP_MEMORY_DOMAINSEP)) + } +} + +pub fn memory_lookups_consecutive(idx_col: ColIndex, values_start: ColIndex, n: usize) -> Vec { + (0..n) + .map(|i| BusInteraction { + direction: BusDirection::Push, + multiplicity: BusMultiplicity::One, + domainsep: BusData::Constant(LOGUP_MEMORY_DOMAINSEP), + data: vec![ + BusData::ColumnPlusConstant(idx_col, i), + BusData::Column(values_start + i), + ], + }) + .collect() +} + +pub fn memory_lookup_groups(buses: &[BusInteraction]) -> Vec { + let mut groups: Vec = Vec::new(); + let mut i = 0; + while i < buses.len() { + if !buses[i].is_memory_lookup() { + i += 1; + continue; + } + let (idx_col, first_ofs) = match buses[i].data[0] { + BusData::ColumnPlusConstant(c, ofs) => (c, ofs), + _ => unreachable!("memory-lookup bus shape is enforced by memory_lookups_consecutive"), + }; + if first_ofs != 0 { + let value_col = match buses[i].data[1] { + BusData::Column(c) => c, + _ => unreachable!("memory-lookup bus shape is enforced by memory_lookups_consecutive"), + }; + groups.push(MemoryLookupGroup { + start_bus: i, + idx_col, + value_cols: vec![value_col], + }); + i += 1; + continue; + } + let mut value_cols = Vec::new(); + let start = i; + let mut expected_ofs = 0; + while i < buses.len() && buses[i].is_memory_lookup() { + let ok = matches!( + buses[i].data[0], + BusData::ColumnPlusConstant(c, ofs) if c == idx_col && ofs == expected_ofs + ); + if !ok { + break; + } + let value_col = match buses[i].data[1] { + BusData::Column(c) => c, + _ => unreachable!("memory-lookup bus shape is enforced by memory_lookups_consecutive"), + }; + value_cols.push(value_col); + i += 1; + expected_ofs += 1; + } + groups.push(MemoryLookupGroup { + start_bus: start, + idx_col, + value_cols, + }); + } + groups +} + +#[derive(Debug)] +pub struct MemoryLookupGroup { + pub start_bus: usize, + pub idx_col: ColIndex, + pub value_cols: Vec, +} + #[derive(Debug, Default)] pub struct TableTrace { pub columns: Vec>, @@ -125,8 +214,7 @@ impl>> ExtraDataForBuses { pub trait TableT: Air { fn name(&self) -> &'static str; fn table(&self) -> Table; - fn lookups(&self) -> Vec; - fn bus(&self) -> Bus; + fn bus_interactions(&self) -> Vec; fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize, ending_pc: usize) -> Vec; fn execute( &self, @@ -145,18 +233,4 @@ pub trait TableT: Air { fn is_execution_table(&self) -> bool { false } - - fn lookup_index_columns<'a>(&'a self, trace: &'a TableTrace) -> Vec<&'a [F]> { - self.lookups() - .iter() - .map(|lookup| &trace.columns[lookup.index][..]) - .collect() - } - fn lookup_value_columns<'a>(&self, trace: &'a TableTrace) -> Vec> { - let mut cols = Vec::new(); - for lookup in self.lookups() { - cols.push(lookup.values.iter().map(|&c| &trace.columns[c][..]).collect()); - } - cols - } } diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 61ebae06..e0040d6e 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -5,7 +5,7 @@ use lean_prover::{ WHIR_SUBSEQUENT_FOLDING_FACTOR, default_whir_config, }; use lean_vm::*; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::OnceLock; use sub_protocols::{N_VARS_TO_SEND_GKR_COEFFS, min_stacked_n_vars, total_whir_statements}; use tracing::instrument; @@ -245,7 +245,7 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree ); replacements.insert( "MAX_BUS_WIDTH_PLACEHOLDER".to_string(), - max_bus_width_including_bytecode().to_string(), + (1 << LOG_MAX_BUS_WIDTH).to_string(), ); replacements.insert( "LOGUP_MEMORY_DOMAINSEP_PLACEHOLDER".to_string(), @@ -266,47 +266,81 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree bytecode_reduction_sumcheck_proof_size(bytecode_point_n_vars).to_string(), ); - let mut lookup_indexes_str = vec![]; - let mut lookup_values_str = vec![]; + let mut one_buses_domseps = vec![]; + let mut one_buses_data_cols = vec![]; + let mut one_buses_data_offsets = vec![]; + let mut one_buses_new_cols = vec![]; let mut num_cols_air = vec![]; let mut air_degrees = vec![]; let mut n_air_columns = vec![]; let mut n_air_shift_columns = vec![]; for table in ALL_TABLES { - let this_look_f_indexes_str = table - .lookups() - .iter() - .map(|lookup_f| lookup_f.index.to_string()) - .collect::>(); - lookup_indexes_str.push(format!("[{}]", this_look_f_indexes_str.join(", "))); + let mut table_domseps = vec![]; + let mut table_data_cols = vec![]; + let mut table_data_offsets = vec![]; + let mut table_new_cols = vec![]; + let mut seen_cols: HashSet = HashSet::new(); + for bus in table.bus_interactions() { + if !matches!(bus.multiplicity, BusMultiplicity::One) { + continue; + } + let BusData::Constant(domsep) = bus.domainsep else { + panic!("Multiplicity::One bus domsep must be a constant"); + }; + let mut data_cols = vec![]; + let mut data_offsets = vec![]; + let mut new_cols = vec![]; + for entry in &bus.data { + let (col, ofs) = match entry { + BusData::Column(c) => (*c, 0), + BusData::ColumnPlusConstant(c, o) => (*c, *o), + BusData::Constant(_) => panic!("Multiplicity::One bus data must be a column"), + }; + data_cols.push(col); + data_offsets.push(ofs); + if seen_cols.insert(col) { + new_cols.push(col); + } + } + table_domseps.push(domsep.to_string()); + table_data_cols.push(format!( + "[{}]", + data_cols.iter().map(usize::to_string).collect::>().join(", ") + )); + table_data_offsets.push(format!( + "[{}]", + data_offsets.iter().map(usize::to_string).collect::>().join(", ") + )); + table_new_cols.push(format!( + "[{}]", + new_cols.iter().map(usize::to_string).collect::>().join(", ") + )); + } + one_buses_domseps.push(format!("[{}]", table_domseps.join(", "))); + one_buses_data_cols.push(format!("[{}]", table_data_cols.join(", "))); + one_buses_data_offsets.push(format!("[{}]", table_data_offsets.join(", "))); + one_buses_new_cols.push(format!("[{}]", table_new_cols.join(", "))); + num_cols_air.push(table.n_columns().to_string()); - let this_lookup_f_values_str = table - .lookups() - .iter() - .map(|lookup_f| { - format!( - "[{}]", - lookup_f - .values - .iter() - .map(|v| v.to_string()) - .collect::>() - .join(", ") - ) - }) - .collect::>(); - lookup_values_str.push(format!("[{}]", this_lookup_f_values_str.join(", "))); air_degrees.push(table.degree_air().to_string()); n_air_columns.push(table.n_columns().to_string()); n_air_shift_columns.push(table.n_shift_columns().to_string()); } replacements.insert( - "LOOKUPS_INDEXES_PLACEHOLDER".to_string(), - format!("[{}]", lookup_indexes_str.join(", ")), + "ONE_BUSES_DOMSEPS_PLACEHOLDER".to_string(), + format!("[{}]", one_buses_domseps.join(", ")), + ); + replacements.insert( + "ONE_BUSES_DATA_COLS_PLACEHOLDER".to_string(), + format!("[{}]", one_buses_data_cols.join(", ")), + ); + replacements.insert( + "ONE_BUSES_DATA_OFFSETS_PLACEHOLDER".to_string(), + format!("[{}]", one_buses_data_offsets.join(", ")), ); replacements.insert( - "LOOKUPS_VALUES_PLACEHOLDER".to_string(), - format!("[{}]", lookup_values_str.join(", ")), + "ONE_BUSES_NEW_COLS_PLACEHOLDER".to_string(), + format!("[{}]", one_buses_new_cols.join(", ")), ); replacements.insert( "NUM_COLS_AIR_PLACEHOLDER".to_string(), @@ -488,7 +522,7 @@ where res += &format!( "\n bus_res: Mut = add_extension_ret(mul_extension_ret({}, logup_alphas_eq_poly + {} * DIM), bus_res_init)", domainsep_str, - max_bus_width_including_bytecode() - 1 + (1 << LOG_MAX_BUS_WIDTH) - 1 ); res += "\n bus_res = mul_extension_ret(bus_res, bus_beta)"; res += &format!("\n sum: Mut = add_extension_ret(bus_res, {})", multiplicity); diff --git a/crates/rec_aggregation/zkdsl_implem/recursion.py b/crates/rec_aggregation/zkdsl_implem/recursion.py index 114d65ac..783e1fa8 100644 --- a/crates/rec_aggregation/zkdsl_implem/recursion.py +++ b/crates/rec_aggregation/zkdsl_implem/recursion.py @@ -18,8 +18,10 @@ LOGUP_BYTECODE_DOMAINSEP = LOGUP_BYTECODE_DOMAINSEP_PLACEHOLDER EXECUTION_TABLE_INDEX = EXECUTION_TABLE_INDEX_PLACEHOLDER -LOOKUPS_INDEXES = LOOKUPS_INDEXES_PLACEHOLDER # [[_; ?]; N_TABLES] -LOOKUPS_VALUES = LOOKUPS_VALUES_PLACEHOLDER # [[[_; ?]; ?]; N_TABLES] +ONE_BUSES_DOMSEPS = ONE_BUSES_DOMSEPS_PLACEHOLDER # [[_; num_buses]; N_TABLES] +ONE_BUSES_DATA_COLS = ONE_BUSES_DATA_COLS_PLACEHOLDER # [[[_; num_data]; num_buses]; N_TABLES] +ONE_BUSES_DATA_OFFSETS = ONE_BUSES_DATA_OFFSETS_PLACEHOLDER # [[[_; num_data]; num_buses]; N_TABLES] +ONE_BUSES_NEW_COLS = ONE_BUSES_NEW_COLS_PLACEHOLDER # [[[_; n_new]; num_buses]; N_TABLES] NUM_COLS_AIR = NUM_COLS_AIR_PLACEHOLDER @@ -114,7 +116,7 @@ def recursion(inner_public_memory, bytecode_hash_domsep): retrieved_numerators_value: Mut = opposite_extension_ret(mul_extension_ret(memory_and_acc_prefix, value_acc)) value_index = mle_of_01234567_etc(point_gkr + (n_vars_logup_gkr - log_memory) * DIM, log_memory) - fingerprint_memory = fingerprint_2(LOGUP_MEMORY_DOMAINSEP, value_memory, value_index, logup_alphas_eq_poly) + fingerprint_memory = fingerprint_2(LOGUP_MEMORY_DOMAINSEP, value_index, value_memory, logup_alphas_eq_poly) retrieved_denominators_value: Mut = mul_extension_ret(memory_and_acc_prefix, sub_extension_ret(logup_c, fingerprint_memory)) offset: Mut = two_exp(log_memory) @@ -278,53 +280,27 @@ def continue_recursion_ordered( bus_numerators_values = DynArray([]) bus_denominators_values = DynArray([]) pcs_points = DynArray([]) # [[_; N]; N_TABLES] - for i in unroll(0, N_TABLES): - pcs_points.push(DynArray([])) pcs_values = DynArray([]) # [[[[] or [_]; num cols]; N]; N_TABLES] pcs_values_shift = DynArray([]) # same structure, for next_mle-weighted column evals for i in unroll(0, N_TABLES): + pcs_points.push(DynArray([])) pcs_values.push(DynArray([])) pcs_values[i].push(DynArray([])) pcs_values_shift.push(DynArray([])) pcs_values_shift[i].push(DynArray([])) - total_num_cols = NUM_COLS_AIR[i] - for _ in unroll(0, total_num_cols): + for _ in unroll(0, NUM_COLS_AIR[i]): pcs_values[i][0].push(DynArray([])) pcs_values_shift[i][0].push(DynArray([])) for sorted_pos in unroll(0, N_TABLES): - table_index: Imu - if sorted_pos == 0: - table_index = EXECUTION_TABLE_INDEX - if sorted_pos == 1: - table_index = second_table - if sorted_pos == 2: - table_index = third_table - # I] Bus (data flow between tables) + table_index = sorted_table_index(sorted_pos, second_table, third_table) log_n_rows = table_log_heights[table_index] n_rows = table_heights[table_index] inner_point = point_gkr + (n_vars_logup_gkr - log_n_rows) * DIM pcs_points[table_index].push(inner_point) - if table_index == EXECUTION_TABLE_INDEX: - # 0] Bytecode lookup - bytecode_prefix = multilinear_location_prefix(offset / n_rows, n_vars_logup_gkr - log_n_rows, point_gkr) - - fs, eval_on_pc = fs_receive_ef_inlined(fs, 1) - pcs_values[EXECUTION_TABLE_INDEX][0][COL_PC].push(eval_on_pc) - fs, instr_evals = fs_receive_ef_inlined(fs, N_INSTRUCTION_COLUMNS) - for i in unroll(0, N_INSTRUCTION_COLUMNS): - global_index = N_COMMITTED_EXEC_COLUMNS + i - pcs_values[EXECUTION_TABLE_INDEX][0][global_index].push(instr_evals + i * DIM) - retrieved_numerators_value = add_extension_ret(retrieved_numerators_value, bytecode_prefix) - fingerp = fingerprint_bytecode(instr_evals, eval_on_pc, logup_alphas_eq_poly) - retrieved_denominators_value = add_extension_ret( - retrieved_denominators_value, - mul_extension_ret(bytecode_prefix, sub_extension_ret(logup_c, fingerp)), - ) - offset += n_rows - + # Bus (data flow between tables — Multiplicity::Column) prefix = multilinear_location_prefix(offset / n_rows, n_vars_logup_gkr - log_n_rows, point_gkr) fs, eval_on_selector = fs_receive_ef_inlined(fs, 1) @@ -339,33 +315,37 @@ def continue_recursion_ordered( offset += n_rows - # II] Lookup into memory - - for lookup_f_index in unroll(0, len(LOOKUPS_INDEXES[table_index])): - col_index = LOOKUPS_INDEXES[table_index][lookup_f_index] - fs, index_eval = fs_receive_ef_inlined(fs, 1) - debug_assert(len(pcs_values[table_index][0][col_index]) == 0) - pcs_values[table_index][0][col_index].push(index_eval) - for i in unroll(0, len(LOOKUPS_VALUES[table_index][lookup_f_index])): - fs, value_eval = fs_receive_ef_inlined(fs, 1) - col_index = LOOKUPS_VALUES[table_index][lookup_f_index][i] - debug_assert(len(pcs_values[table_index][0][col_index]) == 0) - pcs_values[table_index][0][col_index].push(value_eval) - - pref = multilinear_location_prefix(offset / n_rows, n_vars_logup_gkr - log_n_rows, point_gkr) # TODO there is some duplication here - retrieved_numerators_value = add_extension_ret(retrieved_numerators_value, pref) - fingerp = fingerprint_2( - LOGUP_MEMORY_DOMAINSEP, - value_eval, - add_base_extension_ret(i, index_eval), - logup_alphas_eq_poly, - ) - retrieved_denominators_value = add_extension_ret( - retrieved_denominators_value, - mul_extension_ret(pref, sub_extension_ret(logup_c, fingerp)), - ) - - offset += n_rows + # Multiplicity::One buses (bytecode lookup + memory lookups). + for one_bus_idx in unroll(0, len(ONE_BUSES_DOMSEPS[table_index])): + domsep = ONE_BUSES_DOMSEPS[table_index][one_bus_idx] + n_new = len(ONE_BUSES_NEW_COLS[table_index][one_bus_idx]) + n_data = len(ONE_BUSES_DATA_COLS[table_index][one_bus_idx]) + + fs, new_evals = fs_receive_ef_inlined(fs, n_new) + + for i in unroll(0, n_new): + new_col = ONE_BUSES_NEW_COLS[table_index][one_bus_idx][i] + debug_assert(len(pcs_values[table_index][0][new_col]) == 0) + pcs_values[table_index][0][new_col].push(new_evals + i * DIM) + + data_evals = Array(n_data * DIM) + for i in unroll(0, n_data): + data_col = ONE_BUSES_DATA_COLS[table_index][one_bus_idx][i] + data_ofs = ONE_BUSES_DATA_OFFSETS[table_index][one_bus_idx][i] + src = pcs_values[table_index][0][data_col][0] + if data_ofs == 0: + copy_5(src, data_evals + i * DIM) + if data_ofs != 0: + copy_5(add_base_extension_ret(data_ofs, src), data_evals + i * DIM) + + pref = multilinear_location_prefix(offset / n_rows, n_vars_logup_gkr - log_n_rows, point_gkr) + retrieved_numerators_value = add_extension_ret(retrieved_numerators_value, pref) + fingerp = fingerprint_n(domsep, data_evals, n_data, logup_alphas_eq_poly) + retrieved_denominators_value = add_extension_ret( + retrieved_denominators_value, + mul_extension_ret(pref, sub_extension_ret(logup_c, fingerp)), + ) + offset += n_rows retrieved_denominators_value = add_extension_ret( retrieved_denominators_value, @@ -391,13 +371,7 @@ def continue_recursion_ordered( initial_sum: Mut = ZERO_VEC_PTR for sorted_pos in unroll(0, N_TABLES): - table_index: Imu - if sorted_pos == 0: - table_index = EXECUTION_TABLE_INDEX - if sorted_pos == 1: - table_index = second_table - if sorted_pos == 2: - table_index = third_table + table_index = sorted_table_index(sorted_pos, second_table, third_table) bus_numerator_value = bus_numerators_values[sorted_pos] bus_denominator_value = bus_denominators_values[sorted_pos] @@ -406,7 +380,7 @@ def continue_recursion_ordered( bus_final_value = opposite_extension_ret(bus_final_value) bus_final_value = add_extension_ret( bus_final_value, - mul_extension_ret(bus_beta, sub_extension_ret(bus_denominator_value, logup_c)), + mul_extension_ret(bus_beta, sub_extension_ret(logup_c, bus_denominator_value)), ) initial_sum = add_extension_ret(initial_sum, mul_extension_ret(eta_powers + sorted_pos * DIM, bus_final_value)) @@ -416,13 +390,7 @@ def continue_recursion_ordered( check_sum: Mut = ZERO_VEC_PTR for sorted_pos in unroll(0, N_TABLES): - table_index: Imu - if sorted_pos == 0: - table_index = EXECUTION_TABLE_INDEX - if sorted_pos == 1: - table_index = second_table - if sorted_pos == 2: - table_index = third_table + table_index = sorted_table_index(sorted_pos, second_table, third_table) log_n_rows = table_log_heights[table_index] total_num_cols = NUM_COLS_AIR[table_index] n_flat_columns = N_AIR_COLUMNS[table_index] @@ -489,13 +457,7 @@ def continue_recursion_ordered( curr_randomness += DIM for sorted_pos in unroll(0, N_TABLES): - table_index: Imu - if sorted_pos == 0: - table_index = EXECUTION_TABLE_INDEX - if sorted_pos == 1: - table_index = second_table - if sorted_pos == 2: - table_index = third_table + table_index = sorted_table_index(sorted_pos, second_table, third_table) debug_assert(len(pcs_points[table_index]) == len(pcs_values[table_index])) for i in unroll(0, len(pcs_values[table_index])): # next_mle-weighted (shift) values come first @@ -603,13 +565,7 @@ def continue_recursion_ordered( curr_randomness += DIM for sorted_pos in unroll(0, N_TABLES): - table_index: Imu - if sorted_pos == 0: - table_index = EXECUTION_TABLE_INDEX - if sorted_pos == 1: - table_index = second_table - if sorted_pos == 2: - table_index = third_table + table_index = sorted_table_index(sorted_pos, second_table, third_table) log_n_rows = table_log_heights[table_index] n_rows = table_heights[table_index] total_num_cols = NUM_COLS_AIR[table_index] @@ -699,12 +655,23 @@ def fingerprint_2(table_index, data_1, data_2, logup_alphas_eq_poly): @inline -def fingerprint_bytecode(instr_evals, eval_on_pc, logup_alphas_eq_poly): - res: Mut = dot_product_ee_ret(instr_evals, logup_alphas_eq_poly, N_INSTRUCTION_COLUMNS) - res = add_extension_ret(res, mul_extension_ret(eval_on_pc, logup_alphas_eq_poly + N_INSTRUCTION_COLUMNS * DIM)) +def sorted_table_index(sorted_pos, second_table, third_table): + table_index: Imu + if sorted_pos == 0: + table_index = EXECUTION_TABLE_INDEX + if sorted_pos == 1: + table_index = second_table + if sorted_pos == 2: + table_index = third_table + return table_index + + +@inline +def fingerprint_n(domsep, data_evals, n, logup_alphas_eq_poly): + res: Mut = dot_product_ee_ret(data_evals, logup_alphas_eq_poly, n) res = add_extension_ret( res, - mul_base_extension_ret(LOGUP_BYTECODE_DOMAINSEP, logup_alphas_eq_poly + (2 ** log2_ceil(MAX_BUS_WIDTH) - 1) * DIM), + mul_base_extension_ret(domsep, logup_alphas_eq_poly + (2 ** log2_ceil(MAX_BUS_WIDTH) - 1) * DIM), ) return res @@ -791,14 +758,11 @@ def compute_stacked_n_vars(log_memory, log_bytecode_padded, tables_heights): def compute_total_gkr_n_vars(log_memory, log_bytecode_padded, tables_heights): total: Mut = two_exp(log_memory) total += two_exp(log_bytecode_padded) - total += tables_heights[EXECUTION_TABLE_INDEX] for table_index in unroll(0, N_TABLES): n_rows = tables_heights[table_index] - total_lookup_values: Mut = 0 - for i in unroll(0, len(LOOKUPS_INDEXES[table_index])): - total_lookup_values += len(LOOKUPS_VALUES[table_index][i]) - total_lookup_values += 1 # for the bus - total += n_rows * total_lookup_values + # +1 for the Multiplicity::Column bus, plus one block per Multiplicity::One bus. + n_buses = len(ONE_BUSES_DOMSEPS[table_index]) + 1 + total += n_rows * n_buses return log2_ceil_runtime(total) diff --git a/crates/sub_protocols/src/logup.rs b/crates/sub_protocols/src/logup.rs index 2c09fde8..341be5ea 100644 --- a/crates/sub_protocols/src/logup.rs +++ b/crates/sub_protocols/src/logup.rs @@ -93,8 +93,8 @@ pub fn prove_generic_logup( - finger_print_packed::( memory_domainsep_packed, &[ - PFPacking::::from_fn(|w| memory[src_idx(p, w)]), PFPacking::::from_fn(|w| F::from_usize(src_idx(p, w))), + PFPacking::::from_fn(|w| memory[src_idx(p, w)]), ], &alphas_packed, ) @@ -130,71 +130,94 @@ pub fn prove_generic_logup( for (table, _) in &tables_log_heights_sorted { let trace = &traces[table]; let log_n_rows = trace.log_n_rows; + let buses = table.bus_interactions(); + let mem_groups = memory_lookup_groups(&buses); + + let mut next_group = 0; + let mut bus_idx = 0; + while bus_idx < buses.len() { + if next_group < mem_groups.len() && mem_groups[next_group].start_bus == bus_idx { + let group = &mem_groups[next_group]; + let group_len = group.value_cols.len(); + let col_index = &trace.columns[group.idx_col]; + let packed_chunk_size = (1 << log_n_rows) / width; + + numerators[offset..][..group_len << log_n_rows] + .par_iter_mut() + .for_each(|n| *n = F::ONE); + + denominators[offset / width..][..group_len * packed_chunk_size] + .par_chunks_exact_mut(packed_chunk_size) + .enumerate() + .for_each(|(i, denom_chunk)| { + let i_field = F::from_usize(i); + let col_value = &trace.columns[group.value_cols[i]]; + denom_chunk.par_iter_mut().enumerate().for_each(|(p, slot)| { + *slot = c_packed + - finger_print_packed::( + memory_domainsep_packed, + &[ + PFPacking::::from_fn(|w| col_index[src_idx(p, w)] + i_field), + PFPacking::::from_fn(|w| col_value[src_idx(p, w)]), + ], + &alphas_packed, + ); + }); + }); + offset += group_len << log_n_rows; + bus_idx += group_len; + next_group += 1; + continue; + } - if *table == Table::execution() { - let pc_column = &trace.columns[COL_PC]; - let bytecode_columns = &trace.columns[N_RUNTIME_COLUMNS..][..N_INSTRUCTION_COLUMNS]; - numerators[offset..][..1 << log_n_rows] - .par_iter_mut() - .for_each(|n| *n = F::ONE); - fill_denoms(&mut denominators[offset / width..][..(1 << log_n_rows) / width], |p| { - let mut data = [PFPacking::::ZERO; N_INSTRUCTION_COLUMNS + 1]; - for k in 0..N_INSTRUCTION_COLUMNS { - data[k] = PFPacking::::from_fn(|w| bytecode_columns[k][src_idx(p, w)]); + let bus = &buses[bus_idx]; + let slice = &mut numerators[offset..][..1 << log_n_rows]; + match bus.multiplicity { + BusMultiplicity::One => { + let val = bus.direction.to_field_flag(); + slice.par_iter_mut().for_each(|n| *n = val); } - data[N_INSTRUCTION_COLUMNS] = PFPacking::::from_fn(|w| pc_column[src_idx(p, w)]); - c_packed - finger_print_packed::(bytecode_domainsep_packed, &data, &alphas_packed) + BusMultiplicity::Column(col) => { + fill_num_from(slice, &trace.columns[col], matches!(bus.direction, BusDirection::Pull)); + } + } + let denom_slot = &mut denominators[offset / width..][..(1 << log_n_rows) / width]; + + let n_data = bus.data.len(); + let mut data_cols: [&[F]; MAX_BUS_WIDTH] = [&[]; MAX_BUS_WIDTH]; + for (k, entry) in bus.data.iter().enumerate() { + match *entry { + BusData::Column(c) => { + data_cols[k] = &trace.columns[c]; + } + _ => { + panic!("Non-Column BusData::data entries are not supported on the fast path"); + } + } + } + let ds_col: Option<&[F]> = match bus.domainsep { + BusData::Column(c) => Some(&trace.columns[c]), + _ => None, + }; + let ds_constant_packed: PFPacking = match bus.domainsep { + BusData::Constant(v) => PFPacking::::from(F::from_usize(v)), + _ => PFPacking::::ZERO, + }; + + fill_denoms(denom_slot, |p| { + let mut data_buf = [PFPacking::::ZERO; MAX_BUS_WIDTH]; + for k in 0..n_data { + let col = data_cols[k]; + data_buf[k] = PFPacking::::from_fn(|w| col[src_idx(p, w)]); + } + let ds = match ds_col { + Some(col) => PFPacking::::from_fn(|w| col[src_idx(p, w)]), + None => ds_constant_packed, + }; + c_packed - finger_print_packed::(ds, &data_buf[..n_data], &alphas_packed) }); offset += 1 << log_n_rows; - } - - // I] Bus - let bus = table.bus(); - let multiplicity = &trace.columns[bus.multiplicity]; - let pull = matches!(bus.direction, BusDirection::Pull); - fill_num_from(&mut numerators[offset..][..1 << log_n_rows], multiplicity, pull); - let bus_data_entries = &bus.data; - let bus_domainsep = bus.domainsep; - let resolve = |entry: BusData, p: usize| match entry { - BusData::Column(col) => PFPacking::::from_fn(|w| trace.columns[col][src_idx(p, w)]), - BusData::Constant(val) => PFPacking::::from(F::from_usize(val)), - }; - fill_denoms(&mut denominators[offset / width..][..(1 << log_n_rows) / width], |p| { - let mut bus_data = [PFPacking::::ZERO; MAX_PRECOMPILE_BUS_WIDTH]; - for (j, entry) in bus_data_entries.iter().enumerate() { - bus_data[j] = resolve(*entry, p); - } - let domainsep = resolve(bus_domainsep, p); - c_packed + finger_print_packed::(domainsep, &bus_data[..bus_data_entries.len()], &alphas_packed) - }); - offset += 1 << log_n_rows; - - // II] Lookup into memory - let value_columns = table.lookup_value_columns(trace); - let index_columns = table.lookup_index_columns(trace); - for (col_index, col_values) in index_columns.iter().zip(&value_columns) { - numerators[offset..][..col_values.len() << log_n_rows] - .par_iter_mut() - .for_each(|n| *n = F::ONE); - let packed_chunk_size = (1 << log_n_rows) / width; - denominators[offset / width..][..col_values.len() * packed_chunk_size] - .par_chunks_exact_mut(packed_chunk_size) - .enumerate() - .for_each(|(i, denom_chunk)| { - let i_field = F::from_usize(i); - denom_chunk.par_iter_mut().enumerate().for_each(|(p, slot)| { - *slot = c_packed - - finger_print_packed::( - memory_domainsep_packed, - &[ - PFPacking::::from_fn(|w| col_values[i][src_idx(p, w)]), - PFPacking::::from_fn(|w| col_index[src_idx(p, w)] + i_field), - ], - &alphas_packed, - ); - }); - }); - offset += col_values.len() << log_n_rows; + bus_idx += 1; } } @@ -244,57 +267,48 @@ pub fn prove_generic_logup( let inner_point = MultilinearPoint(from_end(&claim_point_gkr, log_n_rows).to_vec()); let mut table_values = BTreeMap::::new(); - if table == &Table::execution() { - let pc_column = &trace.columns[COL_PC]; - let bytecode_columns = trace.columns[N_RUNTIME_COLUMNS..][..N_INSTRUCTION_COLUMNS] - .iter() - .collect::>(); - - let eval_on_pc = pc_column.evaluate(&inner_point); - prover_state.add_extension_scalar(eval_on_pc); - assert!(!table_values.contains_key(&COL_PC)); - table_values.insert(COL_PC, eval_on_pc); - - let instr_evals = bytecode_columns - .iter() - .map(|col| col.evaluate(&inner_point)) - .collect::>(); - prover_state.add_extension_scalars(&instr_evals); - for (i, eval_on_instr_col) in instr_evals.iter().enumerate() { - let global_index = N_RUNTIME_COLUMNS + i; - assert!(!table_values.contains_key(&global_index)); - table_values.insert(global_index, *eval_on_instr_col); + let resolve_ef = |entry: BusData| -> EF { + match entry { + BusData::Column(col) => trace.columns[col].evaluate(&inner_point), + BusData::ColumnPlusConstant(col, ofs) => trace.columns[col].evaluate(&inner_point) + F::from_usize(ofs), + BusData::Constant(val) => EF::from_usize(val), } - } - - let bus = table.bus(); - let eval_on_multiplicity = - trace.columns[bus.multiplicity].evaluate(&inner_point) * bus.direction.to_field_flag(); - prover_state.add_extension_scalar(eval_on_multiplicity); - - let resolve = |entry: BusData| match entry { - BusData::Column(col) => trace.columns[col].evaluate(&inner_point), - BusData::Constant(val) => EF::from_usize(val), }; - let bus_data_evals: Vec = bus.data.iter().map(|entry| resolve(*entry)).collect(); - let eval_on_data = c + finger_print(resolve(bus.domainsep), &bus_data_evals, alphas_eq_poly); - prover_state.add_extension_scalar(eval_on_data); - - bus_numerators_values.insert(*table, eval_on_multiplicity); - bus_denominators_values.insert(*table, eval_on_data); - - // II] Lookup into memory - for lookup in table.lookups() { - let index_eval = trace.columns[lookup.index].evaluate(&inner_point); - prover_state.add_extension_scalar(index_eval); - assert!(!table_values.contains_key(&lookup.index)); - table_values.insert(lookup.index, index_eval); - - for col_index in &lookup.values { - let value_eval = trace.columns[*col_index].evaluate(&inner_point); - prover_state.add_extension_scalar(value_eval); - assert!(!table_values.contains_key(col_index)); - table_values.insert(*col_index, value_eval); + + for bus in table.bus_interactions() { + match bus.multiplicity { + BusMultiplicity::Column(mult_col) => { + let eval_on_multiplicity = + trace.columns[mult_col].evaluate(&inner_point) * bus.direction.to_field_flag(); + prover_state.add_extension_scalar(eval_on_multiplicity); + let data_evals: Vec = bus.data.iter().map(|e| resolve_ef(*e)).collect(); + let eval_on_data = c - finger_print(resolve_ef(bus.domainsep), &data_evals, alphas_eq_poly); + prover_state.add_extension_scalar(eval_on_data); + bus_numerators_values.insert(*table, eval_on_multiplicity); + bus_denominators_values.insert(*table, eval_on_data); + } + BusMultiplicity::One => { + // Skip columns already in table_values: memory-lookup groups share + // an idx column across buses, so it's written once per group rather + // than once per bus. This also keeps simple-lookup writes (e.g. the + // bytecode bus) batched into a single RATE-aligned transcript block. + let col_evals: Vec = bus + .data + .iter() + .filter_map(|entry| { + entry.column().and_then(|col| { + if let std::collections::btree_map::Entry::Vacant(e) = table_values.entry(col) { + let v = trace.columns[col].evaluate(&inner_point); + e.insert(v); + Some(v) + } else { + None + } + }) + }) + .collect(); + prover_state.add_extension_scalars(&col_evals); + } } } @@ -356,7 +370,7 @@ pub fn verify_generic_logup( retrieved_denominators_value += pref * (c - finger_print( EF::from_usize(LOGUP_MEMORY_DOMAINSEP), - &[value_memory, value_index], + &[value_index, value_memory], alphas_eq_poly, )); let mut offset = 1 << log_memory; @@ -397,62 +411,54 @@ pub fn verify_generic_logup( for &(table, log_n_rows) in &tables_heights_sorted { let mut table_values = BTreeMap::::new(); - if table == Table::execution() { - // 0] bytecode lookup - let eval_on_pc = verifier_state.next_extension_scalar()?; - table_values.insert(COL_PC, eval_on_pc); - - let instr_evals = verifier_state.next_extension_scalars_vec(N_INSTRUCTION_COLUMNS)?; - for (i, eval_on_instr_col) in instr_evals.iter().enumerate() { - table_values.insert(N_RUNTIME_COLUMNS + i, *eval_on_instr_col); - } - + for bus in table.bus_interactions() { let pref = pref_at(offset, log_n_rows); - retrieved_numerators_value += pref; // numerator is 1 - retrieved_denominators_value += pref - * (c - finger_print( - EF::from_usize(LOGUP_BYTECODE_DOMAINSEP), - &[instr_evals, vec![eval_on_pc]].concat(), - alphas_eq_poly, - )); - - offset += 1 << log_n_rows; - } - - // I] Bus (data flow between tables) - let eval_on_multiplicity = verifier_state.next_extension_scalar()?; - let pref = pref_at(offset, log_n_rows); - retrieved_numerators_value += pref * eval_on_multiplicity; - - let eval_on_data = verifier_state.next_extension_scalar()?; - retrieved_denominators_value += pref * eval_on_data; - - bus_numerators_values.insert(table, eval_on_multiplicity); - bus_denominators_values.insert(table, eval_on_data); - - offset += 1 << log_n_rows; - - // II] Lookup into memory - for lookup in table.lookups() { - let index_eval = verifier_state.next_extension_scalar()?; - assert!(!table_values.contains_key(&lookup.index)); - table_values.insert(lookup.index, index_eval); - - for (i, col_index) in lookup.values.iter().enumerate() { - let value_eval = verifier_state.next_extension_scalar()?; - assert!(!table_values.contains_key(col_index)); - table_values.insert(*col_index, value_eval); - - let pref = pref_at(offset, log_n_rows); - retrieved_numerators_value += pref; // numerator is 1 - retrieved_denominators_value += pref - * (c - finger_print( - EF::from_usize(LOGUP_MEMORY_DOMAINSEP), - &[value_eval, index_eval + F::from_usize(i)], - alphas_eq_poly, - )); - offset += 1 << log_n_rows; + match bus.multiplicity { + BusMultiplicity::Column(_) => { + let eval_on_multiplicity = verifier_state.next_extension_scalar()?; + let eval_on_data = verifier_state.next_extension_scalar()?; + retrieved_numerators_value += pref * eval_on_multiplicity; + retrieved_denominators_value += pref * eval_on_data; + bus_numerators_values.insert(table, eval_on_multiplicity); + bus_denominators_values.insert(table, eval_on_data); + } + BusMultiplicity::One => { + let n_col_entries = bus + .data + .iter() + .filter(|e| e.column().is_some_and(|col| !table_values.contains_key(&col))) + .count(); + let col_evals = verifier_state.next_extension_scalars_vec(n_col_entries)?; + let mut eval_iter = col_evals.into_iter(); + let data_evals: Vec = bus + .data + .iter() + .map(|entry| match *entry { + BusData::Constant(val) => EF::from_usize(val), + BusData::Column(col) | BusData::ColumnPlusConstant(col, _) => { + let v = if let Some(&cached) = table_values.get(&col) { + cached + } else { + let v = eval_iter.next().unwrap(); + table_values.insert(col, v); + v + }; + match *entry { + BusData::ColumnPlusConstant(_, ofs) => v + F::from_usize(ofs), + _ => v, + } + } + }) + .collect(); + let BusData::Constant(domainsep) = bus.domainsep else { + unreachable!("multiplicity-One bus domsep must be a constant"); + }; + retrieved_numerators_value += pref * bus.direction.to_field_flag(); + retrieved_denominators_value += + pref * (c - finger_print(EF::from_usize(domainsep), &data_evals, alphas_eq_poly)); + } } + offset += 1 << log_n_rows; } columns_values.insert(table, table_values); @@ -483,8 +489,7 @@ pub fn verify_generic_logup( } fn offset_for_table(table: &Table, log_n_rows: usize) -> usize { - let num_cols = table.lookups().iter().map(|l| l.values.len()).sum::() + 1; // +1 for the bus - num_cols << log_n_rows + table.bus_interactions().len() << log_n_rows } pub fn compute_total_logup_log_size( @@ -504,14 +509,8 @@ fn compute_total_active_len( tables_heights_sorted: &[(Table, VarCount)], ) -> usize { let max_table_height = 1 << tables_heights_sorted[0].1; - let log_n_cycles = tables_heights_sorted - .iter() - .find(|(table, _)| *table == Table::execution()) - .unwrap() - .1; (1 << log_memory) + (1 << log_bytecode).max(max_table_height) - + (1 << log_n_cycles) + tables_heights_sorted .iter() .map(|(table, log_n_rows)| offset_for_table(table, *log_n_rows)) diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index fd7760bf..7ba0b515 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -1,7 +1,7 @@ use backend::*; use lean_vm::{ - ALL_TABLES, COL_PC, CommittedStatements, MIN_LOG_MEMORY_SIZE, MIN_LOG_N_ROWS_PER_TABLE, N_INSTRUCTION_COLUMNS, - STARTING_PC, sort_tables_by_height, + ALL_TABLES, COL_PC, ColIndex, CommittedStatements, MIN_LOG_MEMORY_SIZE, MIN_LOG_N_ROWS_PER_TABLE, + N_INSTRUCTION_COLUMNS, STARTING_PC, sort_tables_by_height, }; use lean_vm::{EF, F, Table, TableT, TableTrace}; use std::collections::BTreeMap; @@ -208,11 +208,15 @@ pub fn total_whir_statements() -> usize { + ALL_TABLES .iter() .map(|table| { - // AIR - table.n_columns() - + table.n_shift_columns() - // Lookups into memory - + table.lookups().iter().map(|lookup| 1 + lookup.values.len()).sum::() + let mut seen_cols = std::collections::HashSet::::new(); + for bus in table.bus_interactions().iter().filter(|b| b.is_memory_lookup()) { + for entry in &bus.data { + if let Some(col) = entry.column() { + seen_cols.insert(col); + } + } + } + table.n_columns() + table.n_shift_columns() + seen_cols.len() }) .sum::() // bytecode lookup diff --git a/crates/sub_protocols/tests/soundness_logup.rs b/crates/sub_protocols/tests/soundness_logup.rs index 400ce13f..bc71eec4 100644 --- a/crates/sub_protocols/tests/soundness_logup.rs +++ b/crates/sub_protocols/tests/soundness_logup.rs @@ -1,8 +1,7 @@ use backend::{Field, log2_ceil_usize}; use lean_prover::SECURITY_BITS; use lean_vm::{ - EF, MAX_BYTECODE_LOG_SIZE, MAX_LOG_MEMORY_SIZE, MAX_LOG_N_ROWS_PER_TABLE, max_bus_width_including_bytecode, - sort_tables_by_height, + EF, LOG_MAX_BUS_WIDTH, MAX_BYTECODE_LOG_SIZE, MAX_LOG_MEMORY_SIZE, MAX_LOG_N_ROWS_PER_TABLE, sort_tables_by_height, }; use std::collections::BTreeMap; use sub_protocols::compute_total_logup_log_size; @@ -15,6 +14,6 @@ fn ensure_logup_soundness_is_suffisant() { &sort_tables_by_height(&BTreeMap::from(MAX_LOG_N_ROWS_PER_TABLE)), ); // TODO explain formula - let logup_error_bits = max_logup_n_vars + log2_ceil_usize(log2_ceil_usize(max_bus_width_including_bytecode())); + let logup_error_bits = max_logup_n_vars + log2_ceil_usize(LOG_MAX_BUS_WIDTH); assert!(SECURITY_BITS + logup_error_bits <= EF::bits()); }