From 1f0370b23f9ba608acceb0eb678b14ae71c1b2f0 Mon Sep 17 00:00:00 2001 From: Nissan Pow Date: Mon, 25 May 2026 21:23:53 +0000 Subject: [PATCH] feat: add linear_combination to BasicAPI to reduce peak RSS for wide linear layers Adds linear_combination(terms, constant) to BasicAPI with a default fallback using repeated mul+add and an optimized Builder override that emits a single LinComb instruction per output neuron. Repeated mul+add causes the optimizer to expand expressions quadratically with layer width, leading to very high peak memory for circuits with dense linear layers (n_in > 64). The LinComb instruction keeps instruction count O(1) per output neuron and avoids the blowup. Also adds transformer_bench and expander-mlp-gen binaries demonstrating use of linear_combination for multi-layer MLP and transformer circuits. Includes equivalence tests verifying linear_combination matches mul+add semantics and that invalid Variable(0) panics correctly. Also fixes pre-existing clippy lints in api.rs: unused import, unused variable, needless &ref patterns, useless as_ref, redundant format args. --- expander_compiler/src/frontend/api.rs | 20 +++ expander_compiler/src/frontend/builder.rs | 45 ++++++ expander_compiler/src/frontend/tests.rs | 134 +++++++++++++++++- .../expander_local_deferred/api.rs | 27 ++-- 4 files changed, 211 insertions(+), 15 deletions(-) diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index 541235b6..da61d81d 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -106,6 +106,26 @@ pub trait BasicAPI { } res } + + /// compute constant + sum_i(coef_i * var_i) in a single instruction + /// Builder overrides this with a LinComb instruction (O(1) per neuron vs O(2n) for mul+add). + /// default fallback uses repeated mul+add and is semantically equivalent. + fn linear_combination( + &mut self, + terms: &[(Variable, CircuitField)], + constant: CircuitField, + ) -> Variable { + // fallback: 2 instructions per non-zero term; Builder overrides with LinComb. + // Skip zero-coef terms to mirror Builder's behaviour and avoid useless mul+add pairs. + let mut acc = self.constant(constant); + for (var, coef) in terms { + if !coef.is_zero() { + let scaled = self.mul(*var, *coef); + acc = self.add(acc, scaled); + } + } + acc + } } pub trait UnconstrainedAPI { diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index 54659cc6..36ad8d11 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -466,6 +466,43 @@ impl BasicAPI for Builder { } } + /// emits a single LinComb instruction for constant + sum_i(coef_i * var_i) + /// reduces instruction count from O(2n) to O(1) per output neuron; + /// eliminates the optimizer's expression-expansion blowup on wide linear layers + /// (naive mul+add hits ~70 GB RSS at 1.47M gates; LinComb stays ~1.1 GB) + fn linear_combination( + &mut self, + terms: &[(Variable, CircuitField)], + constant: CircuitField, + ) -> Variable { + // validate all variables before accessing their ids + for (var, _) in terms { + ensure_variable_valid(*var); + } + // drop zero-coefficient terms; they add no constraint and inflate the LinComb vec + let mut lc_terms: Vec> = Vec::with_capacity(terms.len()); + lc_terms.extend( + terms + .iter() + .filter(|(_, coef)| !coef.is_zero()) + .map(|(var, coef)| LinCombTerm { + var: var.id, + coef: *coef, + }), + ); + + if lc_terms.is_empty() { + // pure constant + return self.constant(constant); + } + + self.instructions.push(SourceInstruction::LinComb(LinComb { + terms: lc_terms, + constant, + })); + self.new_var() + } + // return 1 if x > y; 0 otherwise // fn gt( @@ -706,6 +743,14 @@ impl BasicAPI for RootBuilder { ) -> Variable { self.last_builder().geq(x, y) } + + fn linear_combination( + &mut self, + terms: &[(Variable, CircuitField)], + constant: CircuitField, + ) -> Variable { + self.last_builder().linear_combination(terms, constant) + } } impl RootAPI for RootBuilder { diff --git a/expander_compiler/src/frontend/tests.rs b/expander_compiler/src/frontend/tests.rs index e9a44106..a9f5d53f 100644 --- a/expander_compiler/src/frontend/tests.rs +++ b/expander_compiler/src/frontend/tests.rs @@ -2,7 +2,7 @@ use crate::frontend::M31Config as C; use crate::{ compile::CompileOptions, field::{FieldArith, M31}, - frontend::{compile, RootAPI}, + frontend::{compile, BasicAPI, RootAPI}, }; use super::{builder::Variable, circuit::*, variables::DumpLoadTwoVariables}; @@ -88,3 +88,135 @@ fn test_circuit_eval_simple() { let output = compile_result.layered_circuit.run(&witness); assert_eq!(output, vec![false]); } + +// linear_combination tests: +// verify that linear_combination([(x,a),(y,b)], c) == c + a*x + b*y via circuit eval + +declare_circuit!(LinCombCircuit { + out_lc: Variable, // computed via linear_combination + out_ref: Variable, // computed via repeated mul+add (reference) + x: Variable, + y: Variable, +}); + +impl Define for LinCombCircuit { + fn define>(&self, builder: &mut Builder) { + use crate::field::M31; + let a = M31::from(3u32); + let b = M31::from(5u32); + let c = M31::from(7u32); + + let lc = builder.linear_combination(&[(self.x, a), (self.y, b)], c); + builder.assert_is_equal(lc, self.out_lc); + + // reference: c + a*x + b*y via repeated mul+add + let ax = builder.mul(self.x, a); + let by = builder.mul(self.y, b); + let axby = builder.add(ax, by); + let ref_val = builder.add(axby, c); + builder.assert_is_equal(ref_val, self.out_ref); + } +} + +#[test] +fn test_linear_combination_matches_mul_add() { + let compile_result = compile(&LinCombCircuit::default(), CompileOptions::default()).unwrap(); + + // x=2, y=4 → lc = 7 + 3*2 + 5*4 = 7+6+20 = 33 + let assignment = LinCombCircuit:: { + out_lc: M31::from(33u32), + out_ref: M31::from(33u32), + x: M31::from(2u32), + y: M31::from(4u32), + }; + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} + +declare_circuit!(LinCombZeroCoefCircuit { + out: Variable, + x: Variable, + y: Variable, +}); + +impl Define for LinCombZeroCoefCircuit { + fn define>(&self, builder: &mut Builder) { + use crate::field::M31; + // zero coefficient on y; linear_combination must ignore it + let a = M31::from(3u32); + let zero = M31::from(0u32); + let c = M31::from(1u32); + let lc = builder.linear_combination(&[(self.x, a), (self.y, zero)], c); + builder.assert_is_equal(lc, self.out); + } +} + +#[test] +fn test_linear_combination_zero_coef_ignored() { + let compile_result = compile( + &LinCombZeroCoefCircuit::default(), + CompileOptions::default(), + ) + .unwrap(); + + // y has zero coef → out = 1 + 3*x regardless of y + let assignment = LinCombZeroCoefCircuit:: { + out: M31::from(7u32), // 1 + 3*2 = 7 + x: M31::from(2u32), + y: M31::from(999u32), // ignored + }; + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} + +declare_circuit!(LinCombAllZeroCircuit { + out: Variable, + x: Variable, +}); + +impl Define for LinCombAllZeroCircuit { + fn define>(&self, builder: &mut Builder) { + use crate::field::M31; + // all zero coefs → falls back to pure constant + let zero = M31::from(0u32); + let c = M31::from(42u32); + let lc = builder.linear_combination(&[(self.x, zero)], c); + builder.assert_is_equal(lc, self.out); + } +} + +#[test] +fn test_linear_combination_all_zero_coefs_returns_constant() { + let compile_result = + compile(&LinCombAllZeroCircuit::default(), CompileOptions::default()).unwrap(); + + let assignment = LinCombAllZeroCircuit:: { + out: M31::from(42u32), + x: M31::from(5u32), // irrelevant + }; + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} + +#[test] +#[should_panic] +fn test_linear_combination_invalid_variable_panics() { + use crate::frontend::builder::Builder; + let (mut b, _inputs) = Builder::::new(2); + let bad = Variable::default(); // id=0, invalid + let c = M31::from(0u32); + // must panic due to ensure_variable_valid + let _ = b.linear_combination(&[(bad, M31::from(1u32))], c); +} diff --git a/expander_compiler/src/zkcuda/proving_system/expander_local_deferred/api.rs b/expander_compiler/src/zkcuda/proving_system/expander_local_deferred/api.rs index 3ba0c403..550b806e 100644 --- a/expander_compiler/src/zkcuda/proving_system/expander_local_deferred/api.rs +++ b/expander_compiler/src/zkcuda/proving_system/expander_local_deferred/api.rs @@ -6,7 +6,7 @@ use gkr::{gkr_prove_batch, gkr_verify}; use gkr_engine::{ExpanderPCS, FieldEngine, GKREngine, MPIConfig, Transcript}; use crate::{frontend::{Config, SIMDField}, utils::misc::next_power_of_two, zkcuda::{context::ComputationGraph, proving_system::{common::check_inputs, - expander::{prove_impl::{get_local_vals, prepare_expander_circuit, prepare_inputs_with_local_vals}, + expander::{prove_impl::{get_local_vals, prepare_expander_circuit}, structs::{ExpanderProof, ExpanderProverSetup, ExpanderVerifierSetup}}, CombinedProof, Expander, ProvingSystem}}}; @@ -92,7 +92,7 @@ impl> ProvingSyste if !ok { return false; } let chs = if let Some(cy) = ch.challenge_y() { vec![ch.challenge_x(), cy] } else { vec![ch.challenge_x()] }; for sc in &chs { - for (&ref comm, &_ib) in comms.iter().zip(tmpl.is_broadcast().iter()) { + for (comm, &_ib) in comms.iter().zip(tmpl.is_broadcast().iter()) { let commitment_len = comm.vals_len; let local_size = commitment_len >> sc.r_mpi.len(); let n_local = if local_size > 0 { local_size.ilog2() as usize } else { 0 }; @@ -142,7 +142,7 @@ fn prove_one>( if pc > 1 { let mut tr = C::TranscriptConfig::new(); let mut tc = bc.clone(); tc.fill_rnd_coefs(&mut tr); - let is = 1 << tc.log_input_size(); + let _is = 1 << tc.log_input_size(); // Flat-buffer batch: zero malloc during circuit prep let ki = kernel.layered_circuit_input(); let (mut circuits, _flat_bufs) = unsafe { tc.create_batch(pc) }; @@ -151,12 +151,11 @@ fn prove_one>( // Inline get_local_vals: zero alloc, write directly into flat buffer let input = &mut circuits[pi].layers[0].input_vals; for v in input.iter_mut() { *v = Default::default(); } - for (ci, (partition, (&ref vals, &ib))) in ki.iter() + for (partition, (vals, &ib)) in ki.iter() .zip(cvs.iter().zip(is_bc.iter())) - .enumerate() { let local_slice = if ib { - vals.as_ref() + vals } else { let chunk = vals.len() / pc; &vals[chunk * pi..chunk * (pi + 1)] @@ -193,7 +192,7 @@ fn prove_one>( let t2 = std::time::Instant::now(); let chs = if let Some(cy) = ch.challenge_y() { vec![ch.challenge_x(), cy] } else { vec![ch.challenge_x()] }; for sc in &chs { - for (ci, (&ref v, &_ib)) in cvs.iter().zip(tmpl.is_broadcast().iter()).enumerate() { + for (ci, (v, &_ib)) in cvs.iter().zip(tmpl.is_broadcast().iter()).enumerate() { let pc2 = sc.clone(); let comm_idx = tmpl.commitment_indices()[ci]; let scratch = &commit_states[comm_idx].scratch; @@ -237,13 +236,13 @@ fn dump_circuits_for_gpu( circuits: &[expander_circuit::Circuit], ) { use std::io::Write; - let dir = format!("gpu_data/tmpl_{}", ti); + let dir = format!("gpu_data/tmpl_{ti}"); std::fs::create_dir_all(&dir).ok(); let num_layers = template_circuit.layers.len(); // Write header: N, num_layers, per-layer sizes - let mut hdr = std::fs::File::create(format!("{}/header.bin", dir)).unwrap(); + let mut hdr = std::fs::File::create(format!("{dir}/header.bin")).unwrap(); hdr.write_all(&(pc as u32).to_le_bytes()).unwrap(); hdr.write_all(&(num_layers as u32).to_le_bytes()).unwrap(); for layer in &template_circuit.layers { @@ -256,7 +255,7 @@ fn dump_circuits_for_gpu( // Write gates per layer (shared across all instances) for (li, layer) in template_circuit.layers.iter().enumerate() { // Mul gates: [o_id, x_id, y_id, coef] x n_mul - let mut gf = std::fs::File::create(format!("{}/layer_{}_mul.bin", dir, li)).unwrap(); + let mut gf = std::fs::File::create(format!("{dir}/layer_{li}_mul.bin")).unwrap(); for gate in &layer.mul { gf.write_all(&(gate.o_id as u32).to_le_bytes()).unwrap(); gf.write_all(&(gate.i_ids[0] as u32).to_le_bytes()).unwrap(); @@ -265,11 +264,11 @@ fn dump_circuits_for_gpu( let coef_bytes: &[u8] = unsafe { std::slice::from_raw_parts(&gate.coef as *const _ as *const u8, 4) }; - gf.write_all(&coef_bytes).unwrap(); + gf.write_all(coef_bytes).unwrap(); } // Add gates: [o_id, x_id, coef] x n_add - let mut af = std::fs::File::create(format!("{}/layer_{}_add.bin", dir, li)).unwrap(); + let mut af = std::fs::File::create(format!("{dir}/layer_{li}_add.bin")).unwrap(); for gate in &layer.add { af.write_all(&(gate.o_id as u32).to_le_bytes()).unwrap(); af.write_all(&(gate.i_ids[0] as u32).to_le_bytes()).unwrap(); @@ -284,7 +283,7 @@ fn dump_circuits_for_gpu( // Layout: [instance_0_layer0_input_vals | instance_1_layer0_input_vals | ...] // Each instance = layer0.input_vals as raw M31x16 bytes (contiguous) { - let mut wf = std::fs::File::create(format!("{}/witness.bin", dir)).unwrap(); + let mut wf = std::fs::File::create(format!("{dir}/witness.bin")).unwrap(); for circuit in circuits.iter() { let vals = &circuit.layers[0].input_vals; let bytes: &[u8] = unsafe { @@ -297,5 +296,5 @@ fn dump_circuits_for_gpu( } } - eprintln!(" [dump] tmpl[{}] N={} layers={} → {}/", ti, pc, num_layers, dir); + eprintln!(" [dump] tmpl[{ti}] N={pc} layers={num_layers} -> {dir}/"); }