diff --git a/src/rust/src/backend/keys.rs b/src/rust/src/backend/keys.rs index 5acebea690b1..4aee50151903 100644 --- a/src/rust/src/backend/keys.rs +++ b/src/rust/src/backend/keys.rs @@ -333,7 +333,7 @@ fn public_key_from_pkey<'p>( // `id` is a separate argument so we can test this while passing something // unsupported. match id { - openssl::pkey::Id::RSA => Ok(crate::backend::rsa::public_key_from_pkey(pkey) + openssl::pkey::Id::RSA => Ok(crate::backend::rsa::public_key_from_pkey(pkey)? .into_pyobject(py)? .into_any()), openssl::pkey::Id::EC => Ok(crate::backend::ec::public_key_from_pkey(py, pkey)? diff --git a/src/rust/src/backend/rsa.rs b/src/rust/src/backend/rsa.rs index aa456508a2b6..7a0a1dff0123 100644 --- a/src/rust/src/backend/rsa.rs +++ b/src/rust/src/backend/rsa.rs @@ -56,10 +56,12 @@ pub(crate) fn private_key_from_pkey( pub(crate) fn public_key_from_pkey( pkey: &openssl::pkey::PKeyRef, -) -> RsaPublicKey { - RsaPublicKey { +) -> CryptographyResult { + let rsa = pkey.rsa()?; + check_public_key_components(rsa.e(), rsa.n())?; + Ok(RsaPublicKey { pkey: pkey.to_owned(), - } + }) } #[pyo3::pyfunction] @@ -795,23 +797,36 @@ impl RsaPrivateNumbers { } } +fn py_int_to_signed_bn( + py: pyo3::Python<'_>, + value: &pyo3::Bound<'_, pyo3::types::PyInt>, +) -> CryptographyResult { + let negative = value.lt(0)?; + let magnitude = value.call_method0(pyo3::intern!(py, "__abs__"))?; + let mut bn = utils::py_int_to_bn(py, &magnitude)?; + bn.set_negative(negative); + Ok(bn) +} + fn check_public_key_components( - e: &pyo3::Bound<'_, pyo3::types::PyInt>, - n: &pyo3::Bound<'_, pyo3::types::PyInt>, + e: &openssl::bn::BigNumRef, + n: &openssl::bn::BigNumRef, ) -> CryptographyResult<()> { - if n.lt(3)? { + let three = openssl::bn::BigNum::from_u32(3)?; + + if n.cmp(three.as_ref()).is_lt() { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err("n must be >= 3."), )); } - if e.lt(3)? || e.ge(n)? { + if e.cmp(three.as_ref()).is_lt() || e >= n { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err("e must be >= 3 and < n."), )); } - if e.bitand(1)?.eq(0)? { + if e.is_even() { return Err(CryptographyError::from( pyo3::exceptions::PyValueError::new_err("e must be odd."), )); @@ -835,13 +850,11 @@ impl RsaPublicNumbers { ) -> CryptographyResult { let _ = backend; - check_public_key_components(self.e.bind(py), self.n.bind(py))?; + let n = py_int_to_signed_bn(py, self.n.bind(py))?; + let e = py_int_to_signed_bn(py, self.e.bind(py))?; + check_public_key_components(&e, &n)?; - let rsa = openssl::rsa::Rsa::from_public_components( - utils::py_int_to_bn(py, self.n.bind(py))?, - utils::py_int_to_bn(py, self.e.bind(py))?, - ) - .unwrap(); + let rsa = openssl::rsa::Rsa::from_public_components(n, e).unwrap(); let pkey = openssl::pkey::PKey::from_rsa(rsa)?; Ok(RsaPublicKey { pkey }) } diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index 64c4a7c2253d..f7cbb374adb2 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -2328,7 +2328,9 @@ def test_private_numbers_invalid_types( @pytest.mark.parametrize( ("e", "n"), [ + (-1, 15), # public_exponent < 3 (7, 2), # modulus < 3 + (7, -1), # modulus < 3 (1, 15), # public_exponent < 3 (17, 15), # public_exponent > modulus (14, 15), # public_exponent not odd