diff --git a/programs/system/src/context.rs b/programs/system/src/context.rs index c26fa6417d..1ff0ff2a58 100644 --- a/programs/system/src/context.rs +++ b/programs/system/src/context.rs @@ -147,15 +147,40 @@ impl<'info> SystemContext<'info> { } } - pub fn set_rollover_fee(&mut self, ix_data_index: u8, fee: u64) { + pub fn set_rollover_fee( + &mut self, + ix_data_index: u8, + fee: u64, + ) -> std::result::Result<(), ProgramError> { let payment = self .rollover_fee_payments .iter_mut() .find(|a| a.0 == ix_data_index); match payment { - Some(payment) => payment.1 += fee, + Some(payment) => { + payment.1 = payment + .1 + .checked_add(fee) + .ok_or(ProgramError::ArithmeticOverflow)?; + } None => self.rollover_fee_payments.push((ix_data_index, fee)), }; + Ok(()) + } + + #[cfg(test)] + fn new_for_test() -> SystemContext<'static> { + SystemContext { + account_indices: Vec::new(), + accounts: Vec::new(), + account_infos: Vec::new(), + hashed_pubkeys: Vec::new(), + addresses: Vec::new(), + rollover_fee_payments: Vec::new(), + network_fee_is_set: false, + legacy_merkle_context: Vec::new(), + invoking_program_id: None, + } } /// Network fee distribution (fees read from tree metadata as `network_fee`): @@ -182,6 +207,22 @@ impl<'info> SystemContext<'info> { } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn set_rollover_fee_accumulates_checked() { + let mut context = SystemContext::new_for_test(); + + context.set_rollover_fee(3, u64::MAX).unwrap(); + let err = context.set_rollover_fee(3, 1).unwrap_err(); + + assert_eq!(err, ProgramError::ArithmeticOverflow); + assert_eq!(context.rollover_fee_payments, vec![(3, u64::MAX)]); + } +} + #[derive(Debug)] pub struct WrappedInstructionData<'a, T: InstructionData<'a>> { instruction_data: T, diff --git a/programs/system/src/processor/create_address_cpi_data.rs b/programs/system/src/processor/create_address_cpi_data.rs index aa0a6d2885..7f028cabf8 100644 --- a/programs/system/src/processor/create_address_cpi_data.rs +++ b/programs/system/src/processor/create_address_cpi_data.rs @@ -123,7 +123,7 @@ pub fn derive_new_addresses<'info, 'a, 'b: 'a, const ADDRESS_ASSIGNMENT: bool>( } cpi_ix_data.addresses[i].address = address; - context.set_rollover_fee(new_address_params.address_queue_index(), rollover_fee); + context.set_rollover_fee(new_address_params.address_queue_index(), rollover_fee)?; } cpi_ix_data.num_address_queues = accounts .iter() diff --git a/programs/system/src/processor/create_outputs_cpi_data.rs b/programs/system/src/processor/create_outputs_cpi_data.rs index aa81653561..d7c94f2f39 100644 --- a/programs/system/src/processor/create_outputs_cpi_data.rs +++ b/programs/system/src/processor/create_outputs_cpi_data.rs @@ -221,7 +221,7 @@ pub fn create_outputs_cpi_data<'a, 'info, T: InstructionData<'a>>( .map_err(ProgramError::from)?; } } - context.set_rollover_fee(current_index as u8, rollover_fee); + context.set_rollover_fee(current_index as u8, rollover_fee)?; } cpi_ix_data.num_output_queues = index_merkle_tree_account as u8;