Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 17 additions & 22 deletions src/lookup.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use p3_air::{Air, BaseAir, ExtensionBuilder, WindowAccess};
use p3_field::{PrimeCharacteristicRing, batch_multiplicative_inverse};
use p3_matrix::{Matrix, dense::RowMajorMatrix};
use p3_maybe_rayon::prelude::*;

use crate::{
builder::{TwoStagedBuilder, symbolic::SymbolicExpression},
Expand Down Expand Up @@ -113,29 +114,23 @@ impl Lookup<Val> {
fingerprint_challenge: &ExtVal,
mut accumulator: ExtVal,
) -> (Vec<RowMajorMatrix<ExtVal>>, Vec<ExtVal>) {
// Collect the number of lookups per circuit while accumulating the total
// number of lookups.
let mut num_lookups_per_circuit = Vec::with_capacity(lookups.len());
let mut total_num_lookups = 0;
for circuit_lookups in lookups {
let num_rows = circuit_lookups.len();
// Every row is assumed to have the same number of lookups, which is
// the number of lookups of the first row.
let num_row_lookups = circuit_lookups[0].len();
let num_circuit_lookups = num_rows * num_row_lookups;
num_lookups_per_circuit.push(num_circuit_lookups);
total_num_lookups += num_circuit_lookups;
}
// Number of lookups per circuit. Every row in a circuit is assumed to
// have the same number of lookups (the lookups are expected to be fully
// padded), so this is taken from the first row.
let num_lookups_per_circuit: Vec<usize> = lookups
.iter()
.map(|circuit_lookups| circuit_lookups.len() * circuit_lookups[0].len())
.collect();

// Compute and collect all messages. There's one message per lookup.
let mut messages = Vec::with_capacity(total_num_lookups);
for circuit_lookups in lookups {
let circuit_messages = circuit_lookups
.iter()
.flatten()
.map(|lookup| lookup.compute_message(lookup_challenge, fingerprint_challenge));
messages.extend(circuit_messages);
}
// Compute the message for each lookup, in flat circuit-major order.
// Flatten the references serially first so the parallel map operates
// on an indexed slice and `collect` can write straight into the
// output Vec without tree-reducing worker buffers.
let flat: Vec<&Self> = lookups.iter().flatten().flatten().collect();
let messages: Vec<ExtVal> = flat
.par_iter()
.map(|lookup| lookup.compute_message(lookup_challenge, fingerprint_challenge))
.collect();

// Compute the inverses of all messages in batch.
let messages_inverses = batch_multiplicative_inverse(&messages);
Expand Down
Loading