diff --git a/crates/whir/src/dft.rs b/crates/whir/src/dft.rs index 277597eb8..8f2901c2e 100644 --- a/crates/whir/src/dft.rs +++ b/crates/whir/src/dft.rs @@ -62,13 +62,38 @@ impl EvalsDft { } pub(crate) fn update_twiddles(&self, fft_len: usize) { - // TODO: This recomputes the entire table from scratch if we - // need it to be larger, which is wasteful. let mut guard = self.twiddles.write().unwrap(); - let curr_max_fft_len = 1 << guard.len(); - if fft_len > curr_max_fft_len { + + let lg_n = log2_strict_usize(fft_len); + + //if the current size is already big enough we don't do anything + if lg_n <= guard.len() { + return; + } + + //if current twiddles is empty we compute from nothing + if guard.is_empty() { *guard = self.roots_of_unity_table(fft_len); + return; } + + let diff_log = lg_n - guard.len(); + let nb_steps = 1 << diff_log; //number of missing points between each preexisting points + + let generator = F::two_adic_generator(lg_n); + + let table = guard[0].clone(); + let mut nth_root = Vec::with_capacity(table.len() * nb_steps); + for &base in &table { + nth_root.extend(generator.shifted_powers(base).take(nb_steps)); + } + + let old_twiddles = std::mem::take(&mut *guard); + let mut twiddles = Vec::with_capacity(diff_log + old_twiddles.len()); + twiddles.extend((0..diff_log).map(|i| nth_root.iter().step_by(1 << i).copied().collect::>())); + twiddles.extend(old_twiddles); + + *guard = twiddles; } }