From 244eef9a119d10a545561ff1959ea628cd3dae8a Mon Sep 17 00:00:00 2001 From: SeungjunLee Date: Wed, 27 May 2026 07:31:29 +0000 Subject: [PATCH 1/4] Init --- CMakeLists.txt | 2 + README.md | 4 - benchmark/benchmark.cpp | 43 +- benchmark/benchmark_blake3.cpp | 2 +- cmake/bundling.cmake | 65 +++ cmake/debConfig.cmake.in | 10 +- cmake/warnings.cmake | 4 +- examples/CMakeLists.txt | 3 + examples/EnDecryption-MultiSecret.cpp | 10 +- examples/EnDecryption-Real.cpp | 193 +++++++ examples/EnDecryption.cpp | 10 +- examples/KeyGeneration.cpp | 4 +- external/CMakeLists.txt | 30 +- include/deb/CKKSTypes.hpp | 60 +- include/deb/Decryptor.hpp | 12 +- include/deb/Encryptor.hpp | 36 +- include/deb/KeyGenerator.hpp | 17 +- include/deb/SecretKeyGenerator.hpp | 62 ++- include/deb/Types.hpp | 1 + include/deb/utils/ModArith.hpp | 116 +++- include/deb/utils/NTT.hpp | 241 ++++++-- include/deb/utils/NTTConfig.hpp | 101 ++++ prebuild/DebFBType.fbs | 2 +- prebuild/DebParamPreset.json | 19 +- src/CKKSTypes.cpp | 170 +++--- src/Decryptor.cpp | 140 +++-- src/Encryptor.cpp | 204 ++++--- src/KeyGenerator.cpp | 115 ++-- src/ModArith.cpp | 70 ++- src/NTT.cpp | 712 ++++++++++++++++++------ src/NTTConfig.cpp | 69 +++ src/SecretKeyGenerator.cpp | 70 ++- src/SeedGenerator.cpp | 1 - src/Serialize.cpp | 10 +- test/CMakeLists.txt | 2 +- test/EnDecryption-test.cpp | 190 +++++++ test/NTT-test.cpp | 758 +++++++++++++++++++++++++- test/TestBase.hpp | 108 +++- test/U32-test.cpp | 148 ++++- 39 files changed, 3155 insertions(+), 659 deletions(-) create mode 100644 cmake/bundling.cmake create mode 100644 examples/EnDecryption-Real.cpp create mode 100644 include/deb/utils/NTTConfig.hpp create mode 100644 src/NTTConfig.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f064f9b..aad7181 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) +# set(CMAKE_CXX_VISIBILITY_PRESET "default") set(CMAKE_VISIBILITY_INLINES_HIDDEN OFF) set(CMAKE_INSTALL_PREFIX @@ -113,6 +114,7 @@ set(DEB_SRC src/KeyGenerator.cpp src/ModArith.cpp src/NTT.cpp + src/NTTConfig.cpp src/OmpUtils.cpp src/Preset.cpp src/RandomGenerator.cpp diff --git a/README.md b/README.md index 50cabba..a462130 100644 --- a/README.md +++ b/README.md @@ -53,10 +53,6 @@ cmake --build build --target install - `DEB_INSTALL_ALEA`: Install the alea library when installing deb. (default: OFF) - `DEB_INSTALL_FLATBUFFERS`: Install the flatbuffers library when installing deb. (default: OFF) - `DEB_RUNTIME_RESOURCE_CHECK`: Enable runtime resource check. (default: ON) -- `DEB_SERIALIZE_API`: Download FlatBuffers and enable serialization api. (default: ON) -- `DEB_SUPPORT_U64`: Compile u64 coefficient word type support. (default: ON) -- `DEB_SUPPORT_U32`: Compile u32 coefficient word type support. (default: OFF) - ## Testing diff --git a/benchmark/benchmark.cpp b/benchmark/benchmark.cpp index 0acf374..8d552ec 100644 --- a/benchmark/benchmark.cpp +++ b/benchmark/benchmark.cpp @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -60,7 +61,7 @@ static void bm_seckey_encryption(benchmark::State &state) { } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - EncryptorT encryptor; + Encryptor encryptor(preset); Ciphertext ctxt(preset); for (auto _ : state) { @@ -82,7 +83,7 @@ static void bm_enckey_encryption(benchmark::State &state) { SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); KeyGenerator keygen(preset); SwitchKey enckey = keygen.genEncKey(sk); - EncryptorT encryptor; + Encryptor encryptor(preset); Ciphertext ctxt(preset); for (auto _ : state) { @@ -101,13 +102,13 @@ template static void bm_decryption(benchmark::State &state) { } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - EncryptorT encryptor; - DecryptorT decryptor; + Encryptor encryptor(preset); + Decryptor decryptor(preset); Ciphertext ctxt(preset); encryptor.encrypt(msg_v, sk, ctxt); for (auto _ : state) { - decryptor.decrypt(ctxt, sk, msg_v); + decryptor.decrypt(ctxt, sk, msg_v.data()); benchmark::DoNotOptimize(msg_v.data()); benchmark::ClobberMemory(); } @@ -122,21 +123,22 @@ template static void bm_decryption_inplace(benchmark::State &state) { } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - EncryptorT encryptor; - DecryptorT decryptor; + Encryptor encryptor(preset); + Decryptor decryptor(preset); Ciphertext ctxt(preset); encryptor.encrypt(msg_v, sk, ctxt); - std::optional ctxt_copy; + std::optional ctxt_tmp; for (auto _ : state) { state.PauseTiming(); - ctxt_copy.emplace(ctxt.deepCopy()); + ctxt_tmp.emplace(ctxt.deepCopy()); state.ResumeTiming(); - decryptor.decryptInplace(ctxt_copy.value(), sk, msg_v.data()); + decryptor.decryptInplace(ctxt_tmp.value(), sk, msg_v.data()); benchmark::DoNotOptimize(msg_v.data()); benchmark::ClobberMemory(); } } + template static void bm_seckey_coeff_encryption(benchmark::State &state) { const Preset preset = T; @@ -147,7 +149,7 @@ static void bm_seckey_coeff_encryption(benchmark::State &state) { } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - EncryptorT encryptor; + Encryptor encryptor(preset); Ciphertext ctxt(preset); for (auto _ : state) { @@ -168,7 +170,7 @@ static void bm_enckey_coeff_encryption(benchmark::State &state) { SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); KeyGenerator keygen(preset); SwitchKey enckey = keygen.genEncKey(sk); - EncryptorT encryptor; + Encryptor encryptor(preset); Ciphertext ctxt(preset); for (auto _ : state) { @@ -187,14 +189,14 @@ static void bm_coeff_decryption(benchmark::State &state) { msg_v.push_back(gen_random_coeff(get_degree(preset))); } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - EncryptorT encryptor(preset); - DecryptorT decryptor(preset); + Encryptor encryptor(preset); + Decryptor decryptor(preset); Ciphertext ctxt(preset); encryptor.encrypt(msg_v, sk, ctxt); for (auto _ : state) { - decryptor.decrypt(ctxt, sk, msg_v); + decryptor.decrypt(ctxt, sk, msg_v.data()); benchmark::DoNotOptimize(msg_v.data()); benchmark::ClobberMemory(); } @@ -209,22 +211,23 @@ static void bm_coeff_decryption_inplace(benchmark::State &state) { msg_v.push_back(gen_random_coeff(get_degree(preset))); } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - EncryptorT encryptor(preset); - DecryptorT decryptor(preset); + Encryptor encryptor(preset); + Decryptor decryptor(preset); Ciphertext ctxt(preset); encryptor.encrypt(msg_v, sk, ctxt); - std::optional ctxt_copy; + std::optional ctxt_tmp; for (auto _ : state) { state.PauseTiming(); - ctxt_copy.emplace(ctxt.deepCopy()); + ctxt_tmp.emplace(ctxt.deepCopy()); state.ResumeTiming(); - decryptor.decryptInplace(ctxt_copy.value(), sk, msg_v.data()); + decryptor.decryptInplace(ctxt_tmp.value(), sk, msg_v.data()); benchmark::DoNotOptimize(msg_v.data()); benchmark::ClobberMemory(); } } + template static void bm_forward_ntt(benchmark::State &state) { utils::NTT ntt(degree, prime); diff --git a/benchmark/benchmark_blake3.cpp b/benchmark/benchmark_blake3.cpp index 0ab7ed0..a3ce1c3 100644 --- a/benchmark/benchmark_blake3.cpp +++ b/benchmark/benchmark_blake3.cpp @@ -29,7 +29,7 @@ std::random_device rd; std::mt19937 gen{rd()}; std::uniform_real_distribution dist{-1.0, 1.0}; -std::uniform_int_distribution dist_u64{0, UINT64_MAX}; +std::uniform_int_distribution dist_u64; using namespace deb; diff --git a/cmake/bundling.cmake b/cmake/bundling.cmake new file mode 100644 index 0000000..29eba00 --- /dev/null +++ b/cmake/bundling.cmake @@ -0,0 +1,65 @@ +# ~~~ +# Copyright 2026 CryptoLab, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ~~~ + +# Merge the object files of `dependency` target with that of `target`. Both +# should be static or object libraries; otherwise this is no-op. +function(merge_archive_if_static target dependency) + # Check if the target has object files + list(APPEND TYPES_HAVING_OBJECTS "STATIC_LIBRARY" "OBJECT_LIBRARY") + get_target_property(IS_STATIC ${target} TYPE) + if(NOT IS_STATIC IN_LIST TYPES_HAVING_OBJECTS) + return() + endif() + + # Check if the dependency is a target and has object files + if(NOT TARGET ${dependency}) + return() + endif() + get_target_property(IS_STATIC ${dependency} TYPE) + if(NOT IS_STATIC IN_LIST TYPES_HAVING_OBJECTS) + return() + endif() + + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + add_custom_command( + TARGET ${target} + POST_BUILD + COMMAND rm -rf ${target}_objs && mkdir ${target}_objs + COMMAND rm -rf ${dependency}_objs && mkdir ${dependency}_objs + COMMAND ${CMAKE_COMMAND} -E chdir ${target}_objs ${CMAKE_AR} -x + $ + COMMAND ${CMAKE_COMMAND} -E chdir ${dependency}_objs ${CMAKE_AR} -x + $ + COMMAND ar -qcs $ ${target}_objs/*.o + ${dependency}_objs/*.o + COMMAND rm -rf ${target}_objs ${dependency}_objs + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) # DEPENDS ${target} + # ${dependency}) + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + add_custom_command( + TARGET ${target} + POST_BUILD + COMMAND lib.exe /OUT:$ $ + $ + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) # DEPENDS ${target} + # ${dependency}) + else() + message( + WARNING + "Failed merging ${target} target with ${dependency}: unsupported compiler" + ) + endif() +endfunction() diff --git a/cmake/debConfig.cmake.in b/cmake/debConfig.cmake.in index 8e870db..2430866 100644 --- a/cmake/debConfig.cmake.in +++ b/cmake/debConfig.cmake.in @@ -16,12 +16,16 @@ @PACKAGE_INIT@ -include("${CMAKE_CURRENT_LIST_DIR}/debTargets.cmake") - include(CMakeFindDependencyMacro) +find_dependency(alea REQUIRED) +find_dependency(flatbuffers REQUIRED) + set(DEB_BUILD_WITH_OMP @DEB_BUILD_WITH_OMP@) if(DEB_BUILD_WITH_OMP) - find_dependency(OpenMP REQUIRED COMPONENTS CXX C) + find_dependency(OpenMP REQUIRED COMPONENTS CXX) endif() unset(DEB_BUILD_WITH_OMP) + +include("${CMAKE_CURRENT_LIST_DIR}/debTargets.cmake") +check_required_components(deb) diff --git a/cmake/warnings.cmake b/cmake/warnings.cmake index caa1b93..88d7047 100644 --- a/cmake/warnings.cmake +++ b/cmake/warnings.cmake @@ -18,7 +18,7 @@ function(set_deb_warnings target) target_compile_options( ${target} PRIVATE - $<$,$,$>: + $<$,$,$>: -Wall -Wconversion -Wextra @@ -28,6 +28,6 @@ function(set_deb_warnings target) -Wunused -Wvla > - $<$: + $<$: /W4>) endfunction() diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index c847deb..9d18b47 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -23,6 +23,9 @@ set(${PROJECT_NAME}_example example_utils ${PROJECT_NAME}) add_executable(EnDecryption EnDecryption.cpp) target_link_libraries(EnDecryption PRIVATE ${${PROJECT_NAME}_example}) +add_executable(EnDecryption-Real EnDecryption-Real.cpp) +target_link_libraries(EnDecryption-Real PRIVATE ${${PROJECT_NAME}_example}) + add_executable(EnDecryption-MultiSecret EnDecryption-MultiSecret.cpp) target_link_libraries(EnDecryption-MultiSecret PRIVATE ${${PROJECT_NAME}_example}) diff --git a/examples/EnDecryption-MultiSecret.cpp b/examples/EnDecryption-MultiSecret.cpp index 44479af..64154a3 100644 --- a/examples/EnDecryption-MultiSecret.cpp +++ b/examples/EnDecryption-MultiSecret.cpp @@ -86,7 +86,7 @@ int main() { // Encrypt with iNTT output { - auto opt = EncryptOptions().NttOut(false); + auto opt = EncryptOptions().NTTOut(false); DebTimer::start("iNTT Output EnDecryption"); enc.encrypt(msg, sk, ctxt, opt); dec.decrypt(ctxt, sk, decrypted_msg); @@ -97,7 +97,7 @@ int main() { // Encrypt with all custom options { DebTimer::start("All Custom Options EnDecryption"); - enc.encrypt(msg, sk, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NttOut(false)); + enc.encrypt(msg, sk, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false)); dec.decrypt(ctxt, sk, decrypted_msg, scale); DebTimer::end(); std::cout << "log2 error = " << compareMessages(msg, decrypted_msg) << " bits" << std::endl; @@ -127,7 +127,7 @@ int main() { // Encrypt with all custom options { DebTimer::start("All Custom Options Coeff EnDecryption"); - enc.encrypt(cmsg, sk, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NttOut(false)); + enc.encrypt(cmsg, sk, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false)); dec.decrypt(ctxt, sk, decrypted_cmsg, scale); DebTimer::end(); std::cout << "log2 error = " << compareCoeffs(cmsg, decrypted_cmsg) << " bits" << std::endl; @@ -161,7 +161,7 @@ int main() { // Encrypt with all custom options { DebTimer::start("All Custom Options EnDecryption with EncKey"); - enc.encrypt(msg, ek, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NttOut(false)); + enc.encrypt(msg, ek, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false)); dec.decrypt(ctxt, sk, decrypted_msg, scale); DebTimer::end(); std::cout << "log2 error = " << compareMessages(msg, decrypted_msg) << " bits" << std::endl; @@ -170,7 +170,7 @@ int main() { // Encrypt with all custom options { DebTimer::start("All Custom Options Coeff EnDecryption with EncKey"); - enc.encrypt(cmsg, ek, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NttOut(false)); + enc.encrypt(cmsg, ek, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false)); dec.decrypt(ctxt, sk, decrypted_cmsg, scale); DebTimer::end(); std::cout << "log2 error = " << compareCoeffs(cmsg, decrypted_cmsg) << " bits" << std::endl; diff --git a/examples/EnDecryption-Real.cpp b/examples/EnDecryption-Real.cpp new file mode 100644 index 0000000..dd553a1 --- /dev/null +++ b/examples/EnDecryption-Real.cpp @@ -0,0 +1,193 @@ +/* +* Copyright 2026 CryptoLab, Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include "ExampleUtils.hpp" + +#include +using namespace std; +using namespace deb; + +namespace { + +// Produces a Message sized to the polynomial degree (not num_slots) with +// imaginary parts zero — the input shape the real-encryption path expects. +Message generateRandomRealMessage(const Preset preset) { + static std::mt19937 eng(std::random_device{}()); + std::uniform_real_distribution dist(-1.0, 1.0); + Message msg(get_degree(preset)); + for (size_t i = 0; i < msg.size(); ++i) { + msg[i].real(dist(eng)); + msg[i].imag(0.0); + } + return msg; +} + +} // namespace + +int main() { + // Real encryption requires a preset whose primes are 4N-friendly. The + // bundled FGbD12L0 preset is one such configuration. + Preset preset = PRESET_EMPTY; + for (auto p : Presets) { + if (std::string(get_preset_name(p)) == "FGbD12L0") { + preset = p; + break; + } + } + if (preset == PRESET_EMPTY) { + std::cerr << "No real-friendly preset (FGbD12L0) found." << std::endl; + return -1; + } + std::cout << "Preset: " << get_preset_name(preset) << std::endl; + + // Build Encryptor / Decryptor with the real_encrypt / real_decrypt flag + // so each ModArith eagerly initializes its cyclic NTT object. + Encryptor enc(preset); + Decryptor dec(preset); + + // Real encryption uses full-degree real-valued messages (not num_slots). + Message msg(get_degree(preset)); // Message to be encrypted + Message decrypted_msg(get_degree(preset)); // Message to hold decrypted data + Ciphertext ctxt(preset); // Ciphertext to hold encrypted data + + // Secret key must be generated in CYCLIC NTT mode so its NTT-domain + // representation matches the ciphertext produced by real encryption. + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset, std::nullopt, + utils::NTTType::CYCLIC); + + // Random real-valued message (imag == 0 for every slot). + msg = generateRandomRealMessage(preset); + + // --------------------------------------------------------------------- + // Message encryption/decryption with secret key (real) + // --------------------------------------------------------------------- + { + // Basic real encryption/decryption. + DebTimer::start("Basic Real EnDecryption"); + enc.encrypt(msg, sk, ctxt, EncryptOptions().RealEncrypt(true)); + dec.decrypt(ctxt, sk, decrypted_msg); + DebTimer::end(); + std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) << " bits" << std::endl; + } + + // Scaled real encryption/decryption. + u64 base_bit = utils::bitWidth(get_primes(preset)[0]); + Real scale = std::pow(2.0, base_bit - 3); + { + auto opt = EncryptOptions().Scale(scale).RealEncrypt(true); + DebTimer::start("Scaled Real EnDecryption"); + enc.encrypt(msg, sk, ctxt, opt); + dec.decrypt(ctxt, sk, decrypted_msg, scale); + DebTimer::end(); + std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) << " bits" << std::endl; + } + + // Real encryption at a custom level. + Size custom_level = get_encryption_level(preset) / 2; + { + auto opt = EncryptOptions().Level(custom_level).RealEncrypt(true); + DebTimer::start("Custom Level Real EnDecryption"); + enc.encrypt(msg, sk, ctxt, opt); + dec.decrypt(ctxt, sk, decrypted_msg); + DebTimer::end(); + std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) << " bits" << std::endl; + } + + // Real encryption with iNTT (coefficient-domain) output. + { + auto opt = EncryptOptions().NTTOut(false).RealEncrypt(true); + DebTimer::start("iNTT Output Real EnDecryption"); + enc.encrypt(msg, sk, ctxt, opt); + dec.decrypt(ctxt, sk, decrypted_msg); + DebTimer::end(); + std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) << " bits" << std::endl; + } + + // All custom options combined. + { + DebTimer::start("All Custom Options Real EnDecryption"); + enc.encrypt(msg, sk, ctxt, + EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false).RealEncrypt(true)); + dec.decrypt(ctxt, sk, decrypted_msg, scale); + DebTimer::end(); + std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) << " bits" << std::endl; + } + + // --------------------------------------------------------------------- + // Coefficient message encryption/decryption with secret key (real) + // --------------------------------------------------------------------- + CoeffMessage cmsg = generateRandomCoeffMessage(preset); + CoeffMessage decrypted_cmsg(preset); + + { + DebTimer::start("Basic Real Coeff EnDecryption"); + enc.encrypt(cmsg, sk, ctxt, EncryptOptions().RealEncrypt(true)); + dec.decrypt(ctxt, sk, decrypted_cmsg); + DebTimer::end(); + std::cout << "log2 error = " << compareCoeff(cmsg, decrypted_cmsg) << " bits" << std::endl; + } + + { + DebTimer::start("All Custom Options Real Coeff EnDecryption"); + enc.encrypt(cmsg, sk, ctxt, + EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false).RealEncrypt(true)); + dec.decrypt(ctxt, sk, decrypted_cmsg, scale); + DebTimer::end(); + std::cout << "log2 error = " << compareCoeff(cmsg, decrypted_cmsg) << " bits" << std::endl; + } + + // --------------------------------------------------------------------- + // (Coefficient) Message encryption with encryption key (real) + // --------------------------------------------------------------------- + KeyGenerator keygen(preset); + SwitchKey ek = keygen.genEncKey(sk); + + { + DebTimer::start("Real Encryption with EncKey"); + enc.encrypt(msg, ek, ctxt, EncryptOptions().RealEncrypt(true)); + dec.decrypt(ctxt, sk, decrypted_msg); + DebTimer::end(); + std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) << " bits" << std::endl; + } + + { + DebTimer::start("Real Coeff Encryption with EncKey"); + enc.encrypt(cmsg, ek, ctxt, EncryptOptions().RealEncrypt(true)); + dec.decrypt(ctxt, sk, decrypted_cmsg); + DebTimer::end(); + std::cout << "log2 error = " << compareCoeff(cmsg, decrypted_cmsg) << " bits" << std::endl; + } + + { + DebTimer::start("All Custom Options Real EnDecryption with EncKey"); + enc.encrypt(msg, ek, ctxt, + EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false).RealEncrypt(true)); + dec.decrypt(ctxt, sk, decrypted_msg, scale); + DebTimer::end(); + std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) << " bits" << std::endl; + } + + { + DebTimer::start("All Custom Options Real Coeff EnDecryption with EncKey"); + enc.encrypt(cmsg, ek, ctxt, + EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false).RealEncrypt(true)); + dec.decrypt(ctxt, sk, decrypted_cmsg, scale); + DebTimer::end(); + std::cout << "log2 error = " << compareCoeff(cmsg, decrypted_cmsg) << " bits" << std::endl; + } + + return 0; +} diff --git a/examples/EnDecryption.cpp b/examples/EnDecryption.cpp index e7d0b3f..54a6542 100644 --- a/examples/EnDecryption.cpp +++ b/examples/EnDecryption.cpp @@ -83,7 +83,7 @@ int main() { } // Encrypt with iNTT output { - auto opt = EncryptOptions().NttOut(false); + auto opt = EncryptOptions().NTTOut(false); DebTimer::start("iNTT Output EnDecryption"); enc.encrypt(msg, sk, ctxt, opt); dec.decrypt(ctxt, sk, decrypted_msg); @@ -94,7 +94,7 @@ int main() { // Encrypt with all custom options { DebTimer::start("All Custom Options EnDecryption"); - enc.encrypt(msg, sk, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NttOut(false)); + enc.encrypt(msg, sk, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false)); dec.decrypt(ctxt, sk, decrypted_msg, scale); DebTimer::end(); std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) << " bits" << std::endl; @@ -120,7 +120,7 @@ int main() { // Encrypt with all custom options { DebTimer::start("All Custom Options Coeff EnDecryption"); - enc.encrypt(cmsg, sk, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NttOut(false)); + enc.encrypt(cmsg, sk, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false)); dec.decrypt(ctxt, sk, decrypted_cmsg, scale); DebTimer::end(); std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) << " bits" << std::endl; @@ -154,7 +154,7 @@ int main() { // Encrypt with all custom options { DebTimer::start("All Custom Options EnDecryption with EncKey"); - enc.encrypt(msg, ek, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NttOut(false)); + enc.encrypt(msg, ek, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false)); dec.decrypt(ctxt, sk, decrypted_msg, scale); DebTimer::end(); std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) << " bits" << std::endl; @@ -163,7 +163,7 @@ int main() { // Encrypt with all custom options { DebTimer::start("All Custom Options Coeff EnDecryption with EncKey"); - enc.encrypt(cmsg, ek, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NttOut(false)); + enc.encrypt(cmsg, ek, ctxt, EncryptOptions().Scale(scale).Level(custom_level).NTTOut(false)); dec.decrypt(ctxt, sk, decrypted_cmsg, scale); DebTimer::end(); std::cout << "log2 error = " << compareCoeff(cmsg, decrypted_cmsg) << " bits" << std::endl; diff --git a/examples/KeyGeneration.cpp b/examples/KeyGeneration.cpp index 7ed04f0..b228030 100644 --- a/examples/KeyGeneration.cpp +++ b/examples/KeyGeneration.cpp @@ -145,8 +145,8 @@ int main() { const Size pad_rank = 1U << (get_log_degree(preset) / 2); const Size num_p = get_num_p(preset); SwitchKey self_modkey(preset, SwitchKeyKind::SWK_MODPACK_SELF); - self_modkey.addAx(num_p, pad_rank, true); - self_modkey.addBx(num_p, pad_rank * get_num_secret(preset), true); + self_modkey.addAx(num_p, pad_rank, utils::NTTType::NEGACYCLIC); + self_modkey.addBx(num_p, pad_rank * get_num_secret(preset), utils::NTTType::NEGACYCLIC); DebTimer::start("Self ModPack Key Bundle Generation"); keygen.genModPackKeyBundleInplace(pad_rank, self_modkey, sk); // inplace keygen DebTimer::end(); diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 3bc3494..a8ff811 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -29,21 +29,23 @@ cpmaddpackage( "ALEA_INSTALL ${DEB_INSTALL_ALEA}" "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}") -cpmaddpackage( - NAME - flatbuffers - GITHUB_REPOSITORY - google/flatbuffers - GIT_TAG - v25.2.10 - OPTIONS - "FLATBUFFERS_BUILD_TESTS OFF" - "FLATBUFFERS_INSTALL ${DEB_INSTALL_FLATBUFFERS}" - "FLATBUFFERS_BUILD_SHAREDLIB ${BUILD_SHARED_LIBS}") +if(DEB_SERIALIZE_API) + cpmaddpackage( + NAME + flatbuffers + GITHUB_REPOSITORY + google/flatbuffers + GIT_TAG + v25.2.10 + OPTIONS + "FLATBUFFERS_BUILD_TESTS OFF" + "FLATBUFFERS_INSTALL ${DEB_INSTALL_FLATBUFFERS}" + "FLATBUFFERS_BUILD_SHAREDLIB ${BUILD_SHARED_LIBS}") -set(flatbuffers_SOURCE_DIR - ${flatbuffers_SOURCE_DIR} - CACHE PATH "" FORCE) + set(flatbuffers_SOURCE_DIR + ${flatbuffers_SOURCE_DIR} + CACHE PATH "" FORCE) +endif() string(TOUPPER "${DEB_EXT_LIB_FOR_SECURE_ZERO}" _deb_secure_zero_backend) diff --git a/include/deb/CKKSTypes.hpp b/include/deb/CKKSTypes.hpp index 45e4488..7021b08 100644 --- a/include/deb/CKKSTypes.hpp +++ b/include/deb/CKKSTypes.hpp @@ -18,6 +18,7 @@ #include "Preset.hpp" #include "SeedGenerator.hpp" +#include "utils/NTTConfig.hpp" #include #include @@ -193,13 +194,26 @@ template class PolyUnitT { */ U prime() const noexcept; /** - * @brief Marks the coefficient representation as NTT or standard domain. + * @brief Marks the coefficient representation as NTT, recording both the + * cyclic kind and the root-finding algorithm used. */ - void setNTT(bool ntt_state) noexcept; + void setNTT(const utils::NTTType ntt_type, + const utils::NTTRootType root_type = + utils::getGlobalNTTRootType()) noexcept; /** * @brief Returns true if the unit is in NTT domain. */ bool isNTT() const noexcept; + /** + * @brief Returns the NTT type (NEGACYCLIC or CYCLIC). Meaningful only when + * isNTT() is true. + */ + utils::NTTType getNTTType() const noexcept; + /** + * @brief Returns the root-finding algorithm used when the NTT was applied. + * Meaningful only when isNTT() is true. + */ + utils::NTTRootType getNTTRootType() const noexcept; /** * @brief Number of coefficients available in this unit. */ @@ -223,7 +237,10 @@ template class PolyUnitT { private: U prime_; - bool ntt_state_; + // NTTType: NONNTT, NEGACYCLIC, CYCLIC + utils::NTTType ntt_type_; + // NTTRootType: MIN, DIRECT, CUSTOM + utils::NTTRootType ntt_root_type_; Size degree_; std::shared_ptr data_ptr_; }; @@ -261,9 +278,12 @@ template class PolynomialT { PolynomialT deepCopy(std::optional num_polyunit = std::nullopt) const; /** - * @brief Marks every unit as NTT or standard domain. + * @brief Marks every unit as NTT, recording the cyclic kind and root + * algorithm. */ - void setNTT(bool ntt_state) noexcept; + void setNTT(const utils::NTTType ntt_type, + const utils::NTTRootType root_type = + utils::getGlobalNTTRootType()) noexcept; /** * @brief Updates current level metadata. */ @@ -369,7 +389,9 @@ template class CiphertextT { /** * @brief Sets NTT state for every polynomial. */ - void setNTT(bool ntt_state); + void setNTT(const utils::NTTType ntt_type, + const utils::NTTRootType root_type = + utils::getGlobalNTTRootType()) noexcept; /** * @brief Updates level metadata. */ @@ -478,8 +500,10 @@ template class SecretKeyT { template class SwitchKeyT { public: SwitchKeyT() = delete; - explicit SwitchKeyT(Preset preset, const SwitchKeyKind type, - const std::optional rot_idx = std::nullopt); + explicit SwitchKeyT( + Preset preset, const SwitchKeyKind type, + const std::optional rot_idx = std::nullopt, + const utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); Preset preset() const noexcept; void setType(const SwitchKeyKind type) noexcept; @@ -488,14 +512,22 @@ template class SwitchKeyT { Size rotIdx() const noexcept; Size dnum() const noexcept; - void addAx(const Size num_polyunit, std::optional size = std::nullopt, - const bool ntt_state = false); + void + addAx(const Size num_polyunit, std::optional size = std::nullopt, + const utils::NTTType ntt_type = utils::NTTType::NONNTT, + const utils::NTTRootType root_type = utils::getGlobalNTTRootType()); void addAx(const PolynomialT &poly); - void addBx(const Size num_polyunit, std::optional size = std::nullopt, - const bool ntt_state = false); + void + addBx(const Size num_polyunit, std::optional size = std::nullopt, + const utils::NTTType ntt_type = utils::NTTType::NONNTT, + const utils::NTTRootType root_type = utils::getGlobalNTTRootType()); void addBx(const PolynomialT &poly); - void setAxNTT(bool ntt_state) noexcept; - void setBxNTT(bool ntt_state) noexcept; + void setAxNTT(const utils::NTTType ntt_type, + const utils::NTTRootType root_type = + utils::getGlobalNTTRootType()) noexcept; + void setBxNTT(const utils::NTTType ntt_type, + const utils::NTTRootType root_type = + utils::getGlobalNTTRootType()) noexcept; Size axSize() const noexcept; Size bxSize() const noexcept; std::vector> &getAx() noexcept; diff --git a/include/deb/Decryptor.hpp b/include/deb/Decryptor.hpp index 0d8ada6..80f6e2f 100644 --- a/include/deb/Decryptor.hpp +++ b/include/deb/Decryptor.hpp @@ -121,7 +121,11 @@ class DecryptorT : public PresetTraits { * @param scale Scaling factor for decoding. */ template - void decode(const PolynomialT &ptxt, MSG &msg, Real scale) const; + void decode(const PolynomialT &ptxt, MSG &msg, Real scale, + bool is_real) const; + + void changeNTTRootType(const utils::NTTRootType root_type); + utils::NTTRootType getNTTRootType() const; private: template @@ -174,9 +178,11 @@ class DecryptorT : public PresetTraits { const PolynomialT &ptxt, FCoeffMessage &coeff, Real scale) \ const; \ prefix template void DecryptorT::decode( \ - const PolynomialT &ptxt, Message &msg, Real scale) const; \ + const PolynomialT &ptxt, Message &msg, Real scale, \ + bool is_real) const; \ prefix template void DecryptorT::decode( \ - const PolynomialT &ptxt, FMessage &msg, Real scale) const; + const PolynomialT &ptxt, FMessage &msg, Real scale, \ + bool is_real) const; #define DECRYPT_TYPE_TEMPLATE(preset, u_type, prefix) \ prefix template class DecryptorT; \ diff --git a/include/deb/Encryptor.hpp b/include/deb/Encryptor.hpp index 0230b61..b059ce6 100644 --- a/include/deb/Encryptor.hpp +++ b/include/deb/Encryptor.hpp @@ -38,6 +38,8 @@ struct EncryptOptions { Real scale = 0; /**< Requested plaintext scale (0 = auto). */ Size level = utils::DEB_MAX_SIZE; /**< Encryption level override. */ bool ntt_out = true; /**< Whether ciphertext output stays in NTT form. */ + bool real_encrypt = + false; /**< Whether to use the real-encryption method. */ /** * @brief Sets the desired scale value. * @param s Requested scale. @@ -61,10 +63,24 @@ struct EncryptOptions { * @param n NTT flag. * @return Reference to this for chaining. */ + EncryptOptions &NTTOut(bool n) { + ntt_out = n; + return *this; + } + /** [deprecated] */ EncryptOptions &NttOut(bool n) { ntt_out = n; return *this; } + /** + * @brief Sets whether to use the real-encryption method. + * @param r Real-encryption flag. + * @return Reference to this for chaining. + */ + EncryptOptions &RealEncrypt(bool r) { + real_encrypt = r; + return *this; + } }; [[maybe_unused]] static EncryptOptions default_opt; @@ -82,7 +98,7 @@ class EncryptorT : public PresetTraits { public: /** * @brief Constructs an encryptor bound to a preset and optional RNG seed. - * @param target_preset Target preset. + * @param target_preset Target preset. Only specified if P is PRESET_EMPTY. * @param seeds Optional deterministic seed. */ explicit EncryptorT(std::optional seeds = std::nullopt); @@ -173,26 +189,32 @@ class EncryptorT : public PresetTraits { * @param msg Input message object. * @param ptxt Output plaintext polynomial. * @param size Number of PolyUnitT entries to encode. - * @param scale Scaling factor for embedding. + * @param opt Encryption options (scale and real_encrypt are required). */ void encode(const MSG &msg, PolynomialT &ptxt, const Size size, - const Real scale) const; + const EncryptOptions &opt) const; + + void changeNTTRootType(const utils::NTTRootType root_type); + utils::NTTRootType getNTTRootType() const; private: /** * @brief Samples a zero-one polynomial. * @param num_polyunit Number of PolyUnitT entries to sample. + * @param ntt_type NTT type to use for the sampled polynomial. */ - void sampleZO(const Size num_polyunit) const; + void sampleZO(const Size num_polyunit, const utils::NTTType ntt_type) const; /** * @brief Samples a Gaussian polynomial. * @param num_polyunit Number of PolyUnitT entries to sample. - * @param do_ntt Whether to apply NTT to the sampled polynomial. + * @param ntt_type NTT type to use for the sampled polynomial. */ - void sampleGaussian(const Size num_polyunit, const bool do_ntt) const; + void sampleGaussian(const Size num_polyunit, + const utils::NTTType ntt_type) const; std::shared_ptr rng_; + // compute buffers mutable PolynomialT ptxt_buffer_; mutable PolynomialT vx_buffer_; @@ -227,7 +249,7 @@ class EncryptorT : public PresetTraits { const Size size) const; \ prefix template void EncryptorT::encode( \ const msg_t &msg, PolynomialT &ptxt, const Size size, \ - const Real scale) const; + const EncryptOptions &opt) const; #define DECL_ENCRYPT_TEMPLATE(preset, u_type, prefix) \ prefix template class EncryptorT; \ diff --git a/include/deb/KeyGenerator.hpp b/include/deb/KeyGenerator.hpp index 3084ca9..e9fa9ab 100644 --- a/include/deb/KeyGenerator.hpp +++ b/include/deb/KeyGenerator.hpp @@ -18,6 +18,7 @@ #include "CKKSTypes.hpp" #include "utils/FFT.hpp" +#include "utils/NTT.hpp" #include "utils/PresetTraits.hpp" #include "utils/RandomGenerator.hpp" @@ -59,6 +60,7 @@ class KeyGeneratorT : public PresetTraits { KeyGeneratorT(const KeyGeneratorT &) = delete; ~KeyGeneratorT() = default; + void addNTTType(const utils::NTTType ntt_type); /** * @brief Generates a switching key that maps one polynomial basis to * another. @@ -69,18 +71,22 @@ class KeyGeneratorT : public PresetTraits { * @param ax_size Optional size hint for the ax buffer. * @param bx_size Optional size hint for the bx buffer. */ - void genSwitchingKey(const PolynomialT *from, const PolynomialT *to, - PolynomialT *ax, PolynomialT *bx, - const Size ax_size = 0, const Size bx_size = 0) const; + void genSwitchingKey( + const PolynomialT *from, const PolynomialT *to, + PolynomialT *ax, PolynomialT *bx, const Size ax_size = 0, + const Size bx_size = 0, + const utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC) const; /** - * @brief Generates an encryption key. + * @brief Generates an encryption key. The NTT type is derived + * automatically from the NTT state stored in @p sk. * @param sk Secret key to generate public key. * @return Newly created encryption key. */ SwitchKeyT genEncKey(const SecretKeyT &sk) const; /** * @brief Generates an encryption key directly into an existing object. + * The NTT type is derived automatically from the NTT state stored in @p sk. * @param enckey Output storage for encryption key. * @param sk Secret key to generate public key. */ @@ -378,8 +384,7 @@ class KeyGeneratorT : public PresetTraits { void frobeniusMapInNTT(const PolynomialT &op, const i32 pow, PolynomialT res) const; - PolynomialT sampleGaussian(const Size num_polyunit, - bool do_ntt = false) const; + PolynomialT sampleGaussian(const Size num_polyunit) const; void sampleUniform(PolynomialT &poly) const; void computeConst(); diff --git a/include/deb/SecretKeyGenerator.hpp b/include/deb/SecretKeyGenerator.hpp index 450580d..5088cac 100644 --- a/include/deb/SecretKeyGenerator.hpp +++ b/include/deb/SecretKeyGenerator.hpp @@ -45,29 +45,42 @@ template class SecretKeyGeneratorT { /** * @brief Generates a new secret key. * @param seeds Optional deterministic RNG seeds. + * @param ntt_type NTT type used for the polynomial embedding (default: + * negacyclic, matching standard CKKS). Pass NTTType::CYCLIC for + * real-HEAAN mode. * @return Fresh secret key. */ SecretKeyT - genSecretKey(std::optional seeds = std::nullopt); + genSecretKey(std::optional seeds = std::nullopt, + utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); /** * @brief Generates a secret key into the provided object. * @param sk Output storage for secret key. * @param seeds Optional deterministic seed override. + * @param ntt_type NTT type for the polynomial embedding. */ - void genSecretKeyInplace(SecretKeyT &sk, - std::optional seeds = std::nullopt); + void + genSecretKeyInplace(SecretKeyT &sk, + std::optional seeds = std::nullopt, + utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); /** * @brief Builds a secret key from explicit coefficient data. * @param coeffs Pointer to coefficient array sized per preset degree. + * @param ntt_type NTT type for the polynomial embedding. * @return Secret key containing the provided coefficients. */ - SecretKeyT genSecretKeyFromCoeff(const i8 *coeffs); + SecretKeyT + genSecretKeyFromCoeff(const i8 *coeffs, + utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); /** * @brief Writes coefficient data into an existing secret key. * @param sk Output storage for secret key. * @param coeffs Pointer to coefficient array sized per preset degree. + * @param ntt_type NTT type for the polynomial embedding. */ - void genSecretKeyFromCoeffInplace(SecretKeyT &sk, const i8 *coeffs); + void genSecretKeyFromCoeffInplace( + SecretKeyT &sk, const i8 *coeffs, + utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); /** * @brief Generates secret-key coefficients deterministically. @@ -93,55 +106,68 @@ template class SecretKeyGeneratorT { * @param preset Target preset. * @param coeffs Pointer to coefficient data. * @param level Optional modulus level limitation. + * @param ntt_type NTT type for the polynomial embedding. * @return Secret key containing the embedded representation. */ static SecretKeyT ComputeEmbedding(const Preset preset, const i8 *coeffs, - std::optional level = std::nullopt); + std::optional level = std::nullopt, + utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); /** * @brief Writes an embedding into an existing secret key. * @param sk Output storage for secret key. * @param coeffs Source coefficient data. + * @param ntt_type NTT type for the polynomial embedding. */ - static void ComputeEmbeddingInplace(SecretKeyT &sk, const i8 *coeffs); + static void ComputeEmbeddingInplace( + SecretKeyT &sk, const i8 *coeffs, + utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); /** * @brief Convenience wrapper that constructs a generator and produces a * secret key. * @param preset Target preset. * @param seeds Optional deterministic seed. + * @param ntt_type NTT type for the polynomial embedding. * @return Newly generated secret key. */ static SecretKeyT GenSecretKey(const Preset preset, - std::optional seeds = std::nullopt); + std::optional seeds = std::nullopt, + utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); /** * @brief Generates a secret key in-place without instantiating a separate * generator. * @param sk Output storage for secret key. * @param seeds Optional deterministic seed. + * @param ntt_type NTT type for the polynomial embedding. */ static void GenSecretKeyInplace(SecretKeyT &sk, - std::optional seeds = std::nullopt); + std::optional seeds = std::nullopt, + utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); /** * @brief Builds a secret key from explicit coefficients without creating * an instance. * @param preset Target preset. * @param coeffs Pointer to coefficient data. + * @param ntt_type NTT type for the polynomial embedding. * @return Secret key containing the provided coefficients. */ - static SecretKeyT GenSecretKeyFromCoeff(const Preset preset, - const i8 *coeffs); + static SecretKeyT + GenSecretKeyFromCoeff(const Preset preset, const i8 *coeffs, + utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); /** * @brief Writes coefficient data into an existing secret key without * instantiating a generator. * @param sk Output storage for secret key. * @param coeffs Pointer to coefficient data. + * @param ntt_type NTT type for the polynomial embedding. */ - static void GenSecretKeyFromCoeffInplace(SecretKeyT &sk, - const i8 *coeffs); + static void GenSecretKeyFromCoeffInplace( + SecretKeyT &sk, const i8 *coeffs, + utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); private: const Preset preset_; @@ -151,21 +177,25 @@ template class SecretKeyGeneratorT { * @brief Ensures a secret key has fully allocated polynomial representations. * @param sk Secret key to complete. * @param level Optional modulus level restriction. + * @param ntt_type NTT type for the polynomial embedding. */ template void completeSecretKey(SecretKeyT &sk, - std::optional level = std::nullopt); + std::optional level = std::nullopt, + utils::NTTType ntt_type = utils::NTTType::NEGACYCLIC); // Explicit instantiation declarations #ifdef DEB_U64 using SecretKeyGenerator = SecretKeyGeneratorT; extern template class SecretKeyGeneratorT; -extern template void completeSecretKey(SecretKey &, std::optional); +extern template void completeSecretKey(SecretKey &, std::optional, + utils::NTTType); #endif #ifdef DEB_U32 using SecretKeyGenerator32 = SecretKeyGeneratorT; extern template class SecretKeyGeneratorT; -extern template void completeSecretKey(SecretKey32 &, std::optional); +extern template void completeSecretKey(SecretKey32 &, std::optional, + utils::NTTType); #endif diff --git a/include/deb/Types.hpp b/include/deb/Types.hpp index 97ae78e..54a8603 100644 --- a/include/deb/Types.hpp +++ b/include/deb/Types.hpp @@ -76,6 +76,7 @@ enum EncodingType { UNKNOWN, /**< No encoding context is available. */ COEFF, /**< Data is treated as coefficient representation. */ SLOT, /**< Data is treated as slot/complex representation. */ + REAL, /**< Data is treated as Real-HEAAN representation. */ }; /** diff --git a/include/deb/utils/ModArith.hpp b/include/deb/utils/ModArith.hpp index 98d3baa..668a344 100644 --- a/include/deb/utils/ModArith.hpp +++ b/include/deb/utils/ModArith.hpp @@ -23,6 +23,7 @@ #include #include +#include namespace deb::utils { @@ -57,16 +58,61 @@ template class ModArith : public DegreeTrait { * The prime is accepted as u64 for compatibility with preset tables; it is * narrowed to U internally. * + * @param degree Polynomial degree. * @param prime Prime modulus (must fit in U for correctness). + * @param default_ntt_is_cyclic Whether the default NTT should be cyclic. */ - explicit ModArith(u64 prime); - explicit ModArith(Size degree, u64 prime); + explicit ModArith(u64 prime, bool default_ntt_is_cyclic = false); + explicit ModArith(Size degree, u64 prime, + bool default_ntt_is_cyclic = false); /** * @brief Returns the modulus associated with this instance. */ inline U getPrime() const { return prime_; } + // Returns the negacyclic NTT or the cyclic NTT. + // The cyclic-true path requires the stored prime to be + // 4N-friendly; the negacyclic path requires 2N-friendly. + inline std::shared_ptr> getNTT(const NTTType ntt_type) const { + if (ntt_type == NTTType::CYCLIC) { + ensureCyclicNTT(); + return cyclic_ntt_; + } else if (ntt_type == NTTType::NEGACYCLIC) { + ensureNegacyclicNTT(); + return ntt_; + } else { + throw std::runtime_error("[ModArith::getNTT] Unsupported NTT type"); + } + }; + + /** + * @brief Returns the NTT root-finding algorithm this ModArith uses + * when constructing its NTT objects. + */ + inline NTTRootType getNTTRootType() const noexcept { return root_type_; } + + /** + * @brief Sets a per-instance NTT root-finding algorithm override. + * + * If the root type is changed from the current value, any existing NTT + * objects are discarded and will be rebuilt. This allows different ModArith + * instances to use different root-finding algorithms and/or custom roots at + * runtime. + */ + void setNTTRootType(NTTRootType rt = getGlobalNTTRootType()) { + if (root_type_ != rt) { + root_type_ = rt; + if (ntt_) { + ntt_.reset(); + ensureNegacyclicNTT(); // Rebuild with new root type. + } + if (cyclic_ntt_) { + cyclic_ntt_.reset(); + ensureCyclicNTT(); // Rebuild with new root type. + } + } + } /** * @brief Returns the precomputed Barrett ratio floor(2^64 / prime). * @@ -224,30 +270,42 @@ template class ModArith : public DegreeTrait { /** * @brief Applies the forward NTT, copying data when op and res differ. */ - inline void forwardNTT(U *op, U *res) const { + inline void forwardNTT(U *op, U *res, + const NTTType ntt_type = NTTType::NEGACYCLIC) const { if (op != res) std::copy_n(op, default_array_size_, res); - forwardNTT(res); + forwardNTT(res, ntt_type); } /** - * @brief Applies the forward NTT in-place. + * @brief Applies the forward NTT in-place. When @p cyclic is true the + * cyclic NTT object is constructed on first use; this requires the + * stored prime to be 4N-friendly. */ - inline void forwardNTT(U *op) const { ntt_->computeForward(op); } + inline void forwardNTT(U *op, + const NTTType ntt_type = NTTType::NEGACYCLIC) const { + getNTT(ntt_type)->computeForward(op); + } /** * @brief Applies the inverse NTT, copying data when op and res differ. */ - inline void backwardNTT(U *op, U *res) const { + inline void + backwardNTT(U *op, U *res, + const NTTType ntt_type = NTTType::NEGACYCLIC) const { if (op != res) std::copy_n(op, default_array_size_, res); - backwardNTT(res); + backwardNTT(res, ntt_type); } /** - * @brief Applies the inverse NTT in-place. + * @brief Applies the inverse NTT in-place. See forwardNTT for the + * lazy-construction note on @p cyclic. */ - inline void backwardNTT(U *op) const { ntt_->computeBackward(op); } + inline void + backwardNTT(U *op, const NTTType ntt_type = NTTType::NEGACYCLIC) const { + getNTT(ntt_type)->computeBackward(op); + } /** * @brief Returns the default vector size configured for this instance. @@ -280,7 +338,37 @@ template class ModArith : public DegreeTrait { u64 two_to_64_; // 2^64 mod prime u64 two_to_64_shoup_; // floor(two_to_64 * 2^64 / prime) - std::shared_ptr> ntt_ = nullptr; + // Per-instance root_type override. + NTTRootType root_type_; + + // Both NTT objects are built lazily on first use. This lets ModArith + // instances constructed for primes that satisfy only one of the + // congruence requirements (2N-friendly vs 4N-friendly) survive + // construction; the mismatched mode only fails when it is actually + // accessed. Construction tables (~degree-sized vectors) are also a + // non-trivial cost, so deferring them helps when the caller never + // exercises one of the modes. + // + // Held by base-class pointer so a user-registered NTTFactory can + // return any subclass (custom NTT backend) and still slot into both + // negacyclic and cyclic accessor paths. + mutable std::shared_ptr> ntt_ = nullptr; + mutable std::shared_ptr> cyclic_ntt_ = nullptr; + + void ensureNegacyclicNTT() const { + if (!ntt_) { + ntt_ = createNTT(default_array_size_, static_cast(prime_), + NTTType::NEGACYCLIC, getNTTRootType()); + } + } + + void ensureCyclicNTT() const { + if (!cyclic_ntt_) { + cyclic_ntt_ = + createNTT(default_array_size_, static_cast(prime_), + NTTType::CYCLIC, getNTTRootType()); + } + } }; /** @@ -289,6 +377,7 @@ template class ModArith : public DegreeTrait { template void forwardNTT(const std::vector> &modarith, PolynomialT &poly, Size num_polyunit = 0, + NTTType ntt_type = utils::NTTType::NEGACYCLIC, [[maybe_unused]] bool expected_ntt_state = false); /** @@ -297,6 +386,7 @@ void forwardNTT(const std::vector> &modarith, template void backwardNTT(const std::vector> &modarith, PolynomialT &poly, Size num_polyunit = 0, + NTTType ntt_type = utils::NTTType::NEGACYCLIC, [[maybe_unused]] bool expected_ntt_state = true); /** @@ -349,10 +439,10 @@ void constMulPoly(const std::vector> &modarith, prefix template class ModArith; \ prefix template void forwardNTT( \ const std::vector> &, PolynomialT &, \ - Size, bool); \ + Size, NTTType, bool); \ prefix template void backwardNTT( \ const std::vector> &, PolynomialT &, \ - Size, bool); \ + Size, NTTType, bool); \ prefix template void addPoly( \ const std::vector> &, \ const PolynomialT &, const PolynomialT &, \ diff --git a/include/deb/utils/NTT.hpp b/include/deb/utils/NTT.hpp index b9fdb43..c7b1881 100644 --- a/include/deb/utils/NTT.hpp +++ b/include/deb/utils/NTT.hpp @@ -18,88 +18,245 @@ #include "CKKSTypes.hpp" #include "utils/Basic.hpp" +#include "utils/NTTConfig.hpp" #include +#include +#include #include #include namespace deb::utils { -/** - * @brief Factorizes n into its distinct prime factors. - * @param s Output set receiving prime factors. - * @param n Number to factor. - */ void findPrimeFactors(std::set &s, u64 n); -/** - * @brief Finds a primitive root modulo prime. - * @param prime Prime modulus. - * @return Primitive root suitable for NTT. - */ u64 findPrimitiveRoot(u64 prime); +// --------------------------------------------------------------------------- +// NTT class hierarchy +// --------------------------------------------------------------------------- + /** - * @brief Implements forward and inverse number-theoretic transforms. + * @brief Abstract base for number-theoretic transform implementations. * - * @tparam U Coefficient word type (u32 or u64, default u64). - * All twiddle-factor storage and coefficient arrays use type U. - * The constructor always accepts u64 arguments (prime and degree) - * for compatibility with preset tables; values are narrowed to U - * internally when U = u32. - * Note: for U = u32, the prime must be < 2^30 so that intermediate - * butterfly values (up to 4·prime) fit in a u32. + * Two concrete subclasses ship below: NTT (negacyclic, evaluates at odd + * powers of a primitive 2N-th root) and NTT_C (cyclic, hem-compatible — + * evaluates the CI subring of Z_q[X]/ via an internal + * conversion()/inversion() pair built on a primitive 4N-th root). + * + * Users may also derive their own subclass to plug in a custom NTT (e.g. an + * SIMD variant or hardware-accelerated kernel) and register it via + * setNTTFactory(); ModArith and friends will pick it up through createNTT(). + * + * @tparam U Coefficient word type (u32 or u64). For U = u32 the prime must + * be < 2^30 so intermediate butterfly values (up to 4·prime) fit. */ -template class NTT { +template class NTT_base { public: - NTT() = default; - /** - * @brief Creates an NTT instance for a modulus and degree. - * @param degree Polynomial degree (must be a power of two). - * @param prime NTT-friendly prime (prime ≡ 1 mod 2·degree). - */ - NTT(u64 degree, u64 prime); + virtual ~NTT_base() = default; + + NTT_base(const NTT_base &) = delete; + NTT_base &operator=(const NTT_base &) = delete; + NTT_base(NTT_base &&) = default; + NTT_base &operator=(NTT_base &&) = default; /** * @brief Performs an in-place forward NTT on the supplied data. * @param op Pointer to coefficient array of length degree. */ - void computeForward(U *op) const; + virtual void computeForward(U *op) const = 0; /** * @brief Performs an in-place inverse NTT on the supplied data. * @param op Pointer to coefficient array of length degree. */ - void computeBackward(U *op) const; + virtual void computeBackward(U *op) const = 0; + + NTTType getType() const noexcept { return type_; } + NTTRootType getRootType() const noexcept { return root_type_; } + u64 getDegree() const noexcept { return degree_; } + U getPrime() const noexcept { return prime_; } + +protected: + NTT_base() = default; + NTT_base(u64 degree, u64 prime, NTTType type, NTTRootType root_type); + + U prime_{}; + U two_prime_{}; + u64 degree_{0}; + NTTType type_{NTTType::NONNTT}; + NTTRootType root_type_{NTTRootType::MIN}; +}; + +/** + * @brief Negacyclic NTT — evaluates at odd powers of a primitive 2N-th root. + * + * Requires `prime ≡ 1 mod 2·degree`. + */ +template class NTT : public NTT_base { +public: + NTT() = default; + + /** + * @brief Builds the negacyclic twiddle tables for (degree, prime). + * + * @param degree Polynomial degree (must be a power of two). + * @param prime NTT-friendly prime (prime ≡ 1 mod 2·degree). + * @param root_type Root-finding algorithm (defaults to current global). + * @throws std::runtime_error if parameters are invalid or the requested + * CUSTOM psi is missing. + */ + NTT(u64 degree, u64 prime, NTTRootType root_type = getGlobalNTTRootType()); + + void computeForward(U *op) const override; + void computeBackward(U *op) const override; + +private: + using NTT_base::prime_; + using NTT_base::two_prime_; + using NTT_base::degree_; + + std::vector psi_rev_; + std::vector psi_inv_rev_; + std::vector psi_rev_shoup_; + std::vector psi_inv_rev_shoup_; + + U degree_inv_{}; + U degree_inv_barrett_{}; + U degree_inv_w_{}; + U degree_inv_w_barrett_{}; +}; + +/** + * @brief Cyclic, hem-compatible NTT. + * + * Picks a primitive 4N-th root ζ and builds the layered twiddle tables that + * hem::NTT::CYCLIC uses, so deb and hem land on the same NTT-domain bins + * for the same (prime, degree). + * + * Requires `prime ≡ 1 mod 4·degree`. + */ +template class NTT_C : public NTT_base { +public: + NTT_C() = default; + + /** + * @brief Builds the cyclic twiddle tables for (degree, prime). + * + * @param degree Polynomial degree (must be a power of two). + * @param prime NTT-friendly prime (prime ≡ 1 mod 4·degree). + * @param root_type Root-finding algorithm (defaults to current global). + * For CUSTOM, a 4N-th root must be registered under the + * key (2*degree, prime). + * @throws std::runtime_error if parameters are invalid or the requested + * CUSTOM zeta is missing. + */ + NTT_C(u64 degree, u64 prime, + NTTRootType root_type = getGlobalNTTRootType()); + + void computeForward(U *op) const override; + void computeBackward(U *op) const override; private: - U prime_; - U two_prime_; - u64 degree_; ///< degree stays u64 (used as loop bounds / array sizes) + using NTT_base::prime_; + using NTT_base::two_prime_; + using NTT_base::degree_; - // Roots of unity (bit-reversed order), stored as U std::vector psi_rev_; std::vector psi_inv_rev_; - std::vector - psi_rev_shoup_; ///< Shoup precomputed: floor(psi · 2^bits / prime) + std::vector psi_rev_shoup_; std::vector psi_inv_rev_shoup_; - // Precomputed values for the last (combined degree-inverse) step of iNTT - U degree_inv_; - U degree_inv_barrett_; - U degree_inv_w_; - U degree_inv_w_barrett_; + U degree_inv_{}; + U degree_inv_barrett_{}; + U degree_inv_w_{}; + U degree_inv_w_barrett_{}; + + // CI <-> cyclic ring conversion tables (powers of the 4N-th root ζ). + std::vector roots_; + std::vector roots_inv_; + std::vector roots_shoup_; + std::vector roots_inv_shoup_; - void computeForwardNativeSingleStep(U *op, u64 t) const; - void computeBackwardNativeSingleStep(U *op, u64 t) const; - void computeBackwardNativeLast(U *op) const; + void conversion(U *op) const; + void inversion(U *op) const; }; +/** + * @brief Factory returning an NTT_base owning the appropriate concrete type. + * + * Use this only when the transform type must be chosen at runtime. When the + * type is known at compile time, construct NTT or NTT_C directly. + */ +template +std::unique_ptr> +makeNTT(u64 degree, u64 prime, NTTType type, + NTTRootType root_type = getGlobalNTTRootType()); + // Explicit instantiation declarations #ifdef DEB_U32 +extern template class NTT_base; extern template class NTT; +extern template class NTT_C; +extern template std::unique_ptr> makeNTT(u64, u64, NTTType, + NTTRootType); #endif #ifdef DEB_U64 +extern template class NTT_base; extern template class NTT; +extern template class NTT_C; +extern template std::unique_ptr> makeNTT(u64, u64, NTTType, + NTTRootType); +#endif + +// --------------------------------------------------------------------------- +// User-overridable NTT factory +// --------------------------------------------------------------------------- + +/** + * @brief Factory signature for producing NTT instances. + * + * Custom implementations should construct an object that satisfies + * NTT_base for the given (degree, prime, type, root_type) and return it + * wrapped in a shared_ptr. Return an empty shared_ptr to delegate to the + * default implementation for that particular combination (useful when the + * custom backend only handles a subset of NTTType / NTTRootType values). + */ +template +using NTTFactory = std::function>( + u64 degree, u64 prime, NTTType type, NTTRootType root_type)>; + +/** + * @brief Registers a custom NTT factory for word type @p U. + * + * Pass an empty/`nullptr` factory to revert to the default (makeNTT). + * Thread-safe. + */ +template void setNTTFactory(NTTFactory factory); + +/** + * @brief Builds an NTT instance, dispatching through the registered factory + * (if any) or falling back to the default makeNTT(). + * + * Used internally by ModArith and friends so that a registered custom NTT + * is honored everywhere a transform object is needed. If the registered + * factory returns an empty shared_ptr the call also falls through to the + * default — letting partial-coverage factories ignore types they don't + * implement. + */ +template +std::shared_ptr> +createNTT(u64 degree, u64 prime, NTTType type, + NTTRootType root_type = getGlobalNTTRootType()); + +#ifdef DEB_U32 +extern template void setNTTFactory(NTTFactory); +extern template std::shared_ptr> createNTT(u64, u64, NTTType, + NTTRootType); +#endif +#ifdef DEB_U64 +extern template void setNTTFactory(NTTFactory); +extern template std::shared_ptr> createNTT(u64, u64, NTTType, + NTTRootType); #endif } // namespace deb::utils diff --git a/include/deb/utils/NTTConfig.hpp b/include/deb/utils/NTTConfig.hpp new file mode 100644 index 0000000..b5e61eb --- /dev/null +++ b/include/deb/utils/NTTConfig.hpp @@ -0,0 +1,101 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "Types.hpp" + +namespace deb::utils { +/** + * @brief Selects the transform type used when constructing an NTT. + * + * - NONNTT: Non NTT state. No forward NTT is applied. + * + * - NEGACYCLIC: Computes polynomial multiplication modulo X^N + 1. + * Requires prime ≡ 1 mod 2·degree. Uses a primitive 2N-th + * root of unity (psi) throughout. + * + * - CYCLIC: Computes polynomial multiplication modulo X^N + 1. + * Also requires prime ≡ 1 mod 4·degree. + */ +enum class NTTType : int { NONNTT = 0, NEGACYCLIC = 1, CYCLIC = 2 }; + +/** + * @brief Selects the primitive-root algorithm used when constructing an NTT. + * + * - MIN: finds the global primitive root of (Z/pZ)* and then selects + * the smallest primitive 2N-th root of unity among all conjugates. + * - DIRECT: derives the primitive 2N-th root directly from the 2-adic + * structure of (p-1) without min-root selection. + * - CUSTOM: the user supplies the primitive 2N-th root of unity via + * registerCustomPsi() before constructing the NTT object. + * This applies to both NEGACYCLIC and CYCLIC transform types. + */ +enum class NTTRootType : int { MIN = 0, DIRECT = 1, CUSTOM = 2 }; + +/** + * @brief Sets the global NTT root-finding algorithm. + */ +void setGlobalNTTRootType(NTTRootType type); + +/** + * @brief Returns the currently active global NTT root-finding algorithm. + */ +NTTRootType getGlobalNTTRootType(); + +/** + * @brief RAII guard that temporarily changes the global NTT root type and + * restores the previous value on destruction. + */ +class ScopedNTTRootType { +public: + explicit ScopedNTTRootType(NTTRootType type) + : prev_(getGlobalNTTRootType()) { + setGlobalNTTRootType(type); + } + ~ScopedNTTRootType() { setGlobalNTTRootType(prev_); } + + ScopedNTTRootType(const ScopedNTTRootType &) = delete; + ScopedNTTRootType &operator=(const ScopedNTTRootType &) = delete; + +private: + NTTRootType prev_; +}; + +/** + * @brief Registers a custom primitive root of unity for use with + * NTTRootType::CUSTOM. + * + * For the negacyclic NTT (`NTT`), call with `(degree, prime, psi)` where + * psi is a primitive 2·degree-th root of unity. + * + * For the cyclic NTT (`NTT_C`), call with `(2*degree, prime, zeta)` where + * zeta is a primitive 4·degree-th root of unity (the doubled key avoids + * collision with the negacyclic entry for the same (degree, prime)). + * + * The function validates that psi^(2*degree) ≡ 1 and psi^degree ≢ 1 + * (mod prime) before storing it. Thread-safe. + */ +void registerCustomPsi(u64 degree, u64 prime, u64 psi); + +namespace detail { +// Looks up a CUSTOM psi from the global registry. registry_key_degree is the +// key stored by registerCustomPsi() — NEGACYCLIC uses (degree, prime), CYCLIC +// uses (2*degree, prime). Throws std::runtime_error if missing. +u64 lookupCustomPsi(u64 registry_key_degree, u64 prime, const char *ctx); +} // namespace detail + +} // namespace deb::utils diff --git a/prebuild/DebFBType.fbs b/prebuild/DebFBType.fbs index 287d4f2..bec1644 100644 --- a/prebuild/DebFBType.fbs +++ b/prebuild/DebFBType.fbs @@ -33,7 +33,7 @@ table Coeff32 { table PolyUnit { prime:uint64; degree:uint32; - ntt_state:bool; + ntt_info:int32; array:[uint64]; } diff --git a/prebuild/DebParamPreset.json b/prebuild/DebParamPreset.json index 1672b1b..b74fb1e 100644 --- a/prebuild/DebParamPreset.json +++ b/prebuild/DebParamPreset.json @@ -22,6 +22,10 @@ "PRIMES": [ 288230376147386369, 2251799810670593 + ], + "SCALE_FACTORS": [ + 40.0, + 40.0 ] }, { @@ -97,20 +101,23 @@ "LOG_DEGREE": 12, "NUM_BASE": 1, "NUM_QP": 1, - "NUM_TP": 1, + "NUM_TP": 2, "HWT": 2730, "ENC_LEVEL": 1, "BITS": [ 32 ], "PRIMES": [ - 2147377153, - 2147352577, - 274877816833 + 1073692673, + 1073668097, + 8380417, + 8273921 ], "SCALE_FACTORS": [ - 29.415, - 29.415 + 30.0, + 30.0, + 23.0, + 23.0 ] } ] diff --git a/src/CKKSTypes.cpp b/src/CKKSTypes.cpp index d105f7a..c9efb54 100644 --- a/src/CKKSTypes.cpp +++ b/src/CKKSTypes.cpp @@ -15,31 +15,58 @@ */ #include "CKKSTypes.hpp" +#include "utils/NTT.hpp" #include + #if defined(_WIN32) #include #endif -namespace deb { +namespace { +inline std::size_t align_up(std::size_t size) { +#if DEB_ALINAS_LEN == 0 + return size; +#else + return (size + DEB_ALINAS_LEN - 1) & ~(DEB_ALINAS_LEN - 1); +#endif +} -#if DEB_ALINAS_LEN != 0 -inline void *deb_aligned_alloc(size_t alignment, size_t size) { +inline void *deb_malloc(const std::size_t bytes) { + if (bytes == 0) { + return nullptr; + } +#if DEB_ALINAS_LEN == 0 + return std::malloc(bytes); +#else #if defined(_WIN32) - return _aligned_malloc(size, alignment); + return _aligned_malloc(bytes, DEB_ALINAS_LEN); #else - return std::aligned_alloc(alignment, size); + return std::aligned_alloc(DEB_ALINAS_LEN, align_up(bytes)); +#endif #endif } -inline void deb_aligned_free(void *ptr) { -#if defined(_WIN32) +inline void deb_free(void *ptr) { + if (ptr == nullptr) { + return; + } +#if DEB_ALINAS_LEN != 0 && defined(_WIN32) _aligned_free(ptr); #else std::free(ptr); #endif + return; } -#endif + +template +inline std::shared_ptr make_aligned_shared_array(std::size_t size) { + return std::shared_ptr(static_cast(deb_malloc(sizeof(U) * size)), + [](U *p) { deb_free(p); }); +} +} // namespace + +namespace deb { //// --------------------------------------------------------------------- //// Implementation of Message @@ -83,39 +110,28 @@ MESSAGE_TYPE_TEMPLATE() // --------------------------------------------------------------------- template PolyUnitT::PolyUnitT(const Preset preset, const Size level, const bool alloc) - : prime_(static_cast(get_primes(preset)[level])), ntt_state_(false), + : prime_(static_cast(get_primes(preset)[level])), + ntt_type_(utils::NTTType::NONNTT), + ntt_root_type_(utils::getGlobalNTTRootType()), degree_(get_degree(preset)) { if (!alloc) { data_ptr_ = nullptr; degree_ = 0; return; } -#if DEB_ALINAS_LEN == 0 - data_ptr_ = - std::shared_ptr(new U[degree_], std::default_delete()); -#else - auto *buf = static_cast( - deb_aligned_alloc(DEB_ALINAS_LEN, sizeof(U) * degree_)); - data_ptr_ = std::shared_ptr(buf, [](U *p) { deb_aligned_free(p); }); -#endif + data_ptr_ = make_aligned_shared_array(degree_); } template PolyUnitT::PolyUnitT(u64 prime, Size degree, const bool alloc) - : prime_(static_cast(prime)), ntt_state_(false), degree_(degree) { + : prime_(static_cast(prime)), ntt_type_(utils::NTTType::NONNTT), + ntt_root_type_(utils::getGlobalNTTRootType()), degree_(degree) { if (!alloc) { data_ptr_ = nullptr; degree_ = 0; return; } -#if DEB_ALINAS_LEN == 0 - data_ptr_ = - std::shared_ptr(new U[degree_], std::default_delete()); -#else - auto *buf = static_cast( - deb_aligned_alloc(DEB_ALINAS_LEN, sizeof(U) * degree_)); - data_ptr_ = std::shared_ptr(buf, [](U *p) { deb_aligned_free(p); }); -#endif + data_ptr_ = make_aligned_shared_array(degree_); } template PolyUnitT PolyUnitT::deepCopy() const { @@ -126,7 +142,8 @@ template PolyUnitT PolyUnitT::deepCopy() const { copy[i] = (*this)[i]; } } - copy.setNTT(ntt_state_); + copy.ntt_type_ = ntt_type_; + copy.ntt_root_type_ = ntt_root_type_; return copy; } @@ -136,12 +153,24 @@ template void PolyUnitT::setPrime(u64 prime) noexcept { template U PolyUnitT::prime() const noexcept { return prime_; } -template void PolyUnitT::setNTT(bool ntt_state) noexcept { - ntt_state_ = ntt_state; +template +void PolyUnitT::setNTT(const utils::NTTType ntt_type, + const utils::NTTRootType root_type) noexcept { + ntt_type_ = ntt_type; + ntt_root_type_ = root_type; } template bool PolyUnitT::isNTT() const noexcept { - return ntt_state_; + return ntt_type_ != utils::NTTType::NONNTT; +} + +template utils::NTTType PolyUnitT::getNTTType() const noexcept { + return ntt_type_; +} + +template +utils::NTTRootType PolyUnitT::getNTTRootType() const noexcept { + return ntt_root_type_; } template Size PolyUnitT::degree() const noexcept { @@ -163,14 +192,7 @@ PolynomialT::PolynomialT(const Preset preset, const bool full_level) { const Size degree = get_degree(preset); const Size num_poly = full_level ? get_num_p(preset) : get_encryption_level(preset) + 1; -#if DEB_ALINAS_LEN == 0 - dealloc_ptr_ = std::shared_ptr(new U[num_poly * degree], - std::default_delete()); -#else - auto *buf = static_cast( - deb_aligned_alloc(DEB_ALINAS_LEN, sizeof(U) * num_poly * degree)); - dealloc_ptr_ = std::shared_ptr(buf, [](U *p) { deb_aligned_free(p); }); -#endif + dealloc_ptr_ = make_aligned_shared_array(num_poly * degree); for (Size l = 0; l < num_poly; ++l) { polyunits_.emplace_back(preset, l, false); polyunits_[l].setData(dealloc_ptr_.get() + l * degree, degree); @@ -180,14 +202,7 @@ PolynomialT::PolynomialT(const Preset preset, const bool full_level) { template PolynomialT::PolynomialT(const Preset preset, const Size custom_size) { const Size degree = get_degree(preset); -#if DEB_ALINAS_LEN == 0 - dealloc_ptr_ = std::shared_ptr(new U[custom_size * degree], - std::default_delete()); -#else - auto *buf = static_cast( - deb_aligned_alloc(DEB_ALINAS_LEN, sizeof(U) * custom_size * degree)); - dealloc_ptr_ = std::shared_ptr(buf, [](U *p) { deb_aligned_free(p); }); -#endif + dealloc_ptr_ = make_aligned_shared_array(custom_size * degree); for (Size l = 0; l < custom_size; ++l) { polyunits_.emplace_back(preset, l, false); polyunits_[l].setData(dealloc_ptr_.get() + l * degree, degree); @@ -211,21 +226,15 @@ PolynomialT::deepCopy(std::optional num_polyunit) const { PolynomialT copy(*this, 0, 0); copy.polyunits_.clear(); if (dealloc_ptr_ != nullptr) { -#if DEB_ALINAS_LEN == 0 - copy.dealloc_ptr_ = std::shared_ptr( - new U[num_polyunit_val * polyunits_[0].degree()], - std::default_delete()); -#else - auto *buf = static_cast( - deb_aligned_alloc(DEB_ALINAS_LEN, sizeof(U) * num_polyunit_val * - polyunits_[0].degree())); - copy.dealloc_ptr_ = - std::shared_ptr(buf, [](U *p) { deb_aligned_free(p); }); -#endif + copy.dealloc_ptr_ = make_aligned_shared_array( + num_polyunit_val * polyunits_[0].degree()); for (Size i = 0; i < num_polyunit_val; ++i) { copy.polyunits_.emplace_back(polyunits_[i].prime(), polyunits_[i].degree(), true); - copy.polyunits_[i].setNTT(polyunits_[i].isNTT()); + if (polyunits_[i].isNTT()) { + copy.polyunits_[i].setNTT(polyunits_[i].getNTTType(), + polyunits_[i].getNTTRootType()); + } copy.polyunits_[i].setData(copy.dealloc_ptr_.get() + i * polyunits_[i].degree(), polyunits_[i].degree()); @@ -242,10 +251,11 @@ PolynomialT::deepCopy(std::optional num_polyunit) const { return copy; } -template void PolynomialT::setNTT(bool ntt_state) noexcept { - for (auto &poly : polyunits_) { - poly.setNTT(ntt_state); - } +template +void PolynomialT::setNTT(const utils::NTTType ntt_type, + const utils::NTTRootType root_type) noexcept { + for (auto &poly : polyunits_) + poly.setNTT(ntt_type, root_type); } template void PolynomialT::setLevel(Preset preset, Size level) { @@ -331,9 +341,11 @@ template bool CiphertextT::isCoeff() const noexcept { return encoding_ == COEFF; } -template void CiphertextT::setNTT(bool ntt_state) { +template +void CiphertextT::setNTT(const utils::NTTType ntt_type, + const utils::NTTRootType root_type) noexcept { for (auto &poly : polys_) { - poly.setNTT(ntt_state); + poly.setNTT(ntt_type, root_type); } } @@ -480,15 +492,17 @@ void SecretKeyT::allocPolys(std::optional num_polyunit) { // --------------------------------------------------------------------- template SwitchKeyT::SwitchKeyT(Preset preset, const SwitchKeyKind type, - const std::optional rot_idx) + const std::optional rot_idx, + const utils::NTTType ntt_type) : preset_(preset), type_(type), rot_idx_(rot_idx), dnum_(get_gadget_rank(preset)) { if (type_ == SWK_MODPACK_SELF || type_ == SWK_GENERIC) { return; } const Size size = (type_ == SWK_ENC) ? 1 : dnum_; - addAx(get_num_p(preset), size, true); - addBx(get_num_p(preset), size * get_num_secret(preset), true); + addAx(get_num_p(preset), size, ntt_type, utils::getGlobalNTTRootType()); + addBx(get_num_p(preset), size * get_num_secret(preset), ntt_type, + utils::getGlobalNTTRootType()); } template Preset SwitchKeyT::preset() const noexcept { @@ -520,12 +534,13 @@ template Size SwitchKeyT::dnum() const noexcept { template void SwitchKeyT::addAx(const Size num_polyunit, std::optional size, - const bool ntt_state) { + const utils::NTTType ntt_type, + const utils::NTTRootType root_type) { const auto num_poly = size.value_or(1); for (Size i = 0; i < num_poly; ++i) { ax_.emplace_back(preset_, num_polyunit); } - setAxNTT(ntt_state); + setAxNTT(ntt_type, root_type); } template void SwitchKeyT::addAx(const PolynomialT &poly) { @@ -534,27 +549,32 @@ template void SwitchKeyT::addAx(const PolynomialT &poly) { template void SwitchKeyT::addBx(const Size num_polyunit, std::optional size, - const bool ntt_state) { + const utils::NTTType ntt_type, + const utils::NTTRootType root_type) { const auto num_poly = size.value_or(dnum_ * get_num_secret(preset_)); for (Size i = 0; i < num_poly; ++i) { bx_.emplace_back(preset_, num_polyunit); } - setBxNTT(ntt_state); + setBxNTT(ntt_type, root_type); } template void SwitchKeyT::addBx(const PolynomialT &poly) { bx_.push_back(poly); } -template void SwitchKeyT::setAxNTT(bool ntt_state) noexcept { +template +void SwitchKeyT::setAxNTT(const utils::NTTType ntt_type, + const utils::NTTRootType root_type) noexcept { for (auto &poly : ax_) { - poly.setNTT(ntt_state); + poly.setNTT(ntt_type, root_type); } } -template void SwitchKeyT::setBxNTT(bool ntt_state) noexcept { +template +void SwitchKeyT::setBxNTT(const utils::NTTType ntt_type, + const utils::NTTRootType root_type) noexcept { for (auto &poly : bx_) { - poly.setNTT(ntt_state); + poly.setNTT(ntt_type, root_type); } } diff --git a/src/Decryptor.cpp b/src/Decryptor.cpp index 8764599..0f3be5f 100644 --- a/src/Decryptor.cpp +++ b/src/Decryptor.cpp @@ -28,20 +28,21 @@ namespace deb { constexpr Size MAX_DECRYPT_SIZE = 2; -template -DecryptorT::DecryptorT() : PresetTraits(preset), fft_(degree) { +template DecryptorT::DecryptorT() : DecryptorT(P) { if constexpr (P == PRESET_EMPTY) { throw std::runtime_error("[Decryptor] Preset template must be " "specified when preset is not given"); } - for (Size i = 0; i < MAX_DECRYPT_SIZE; ++i) { - modarith.emplace_back(primes[i]); - } } template DecryptorT::DecryptorT(const Preset target_preset) - : PresetTraits(target_preset), fft_(degree) { + : PresetTraits(target_preset), fft_(degree << 1) { + if (target_preset == PRESET_EMPTY) { + throw std::runtime_error( + "[Decryptor] Target preset must be " + "specified when PRESET_EMPTY template is used"); + } for (Size i = 0; i < MAX_DECRYPT_SIZE; ++i) { modarith.emplace_back(degree, primes[i]); } @@ -67,7 +68,6 @@ void DecryptorT::decrypt(const CiphertextT &ctxt, ctxt.deepCopy(std::min(ctxt[0].size(), MAX_DECRYPT_SIZE)); decryptInplace(ctxt_copy, sk, msg, scale); } - template template void DecryptorT::decryptInplace(CiphertextT &ctxt, @@ -80,6 +80,7 @@ void DecryptorT::decryptInplace(CiphertextT &ctxt, deb_assert(sk[0].size() >= ctxt[0].size(), "[Decryptor::decrypt] Level of secret key must be greater than " "or equal to ciphertext level"); + if (scale == 0) scale = std::pow(2.0, -scale_factors[ctxt[0].size() - 1]); else @@ -89,33 +90,25 @@ void DecryptorT::decryptInplace(CiphertextT &ctxt, static_cast(ctxt[0].size() * (degree >> 10)); utils::setOmpThreadLimit(max_num_threads); + const bool is_real = ctxt.encoding() == REAL; + const utils::NTTType ntt_type = + is_real ? utils::NTTType::CYCLIC : utils::NTTType::NEGACYCLIC; const Size num_polyunit = std::min(ctxt[0].size(), MAX_DECRYPT_SIZE); - PolynomialT ax(ctxt[ctxt.numPoly() - 1]); - ax.setSize(preset, num_polyunit); - + PolynomialT &ax = ctxt[ctxt.numPoly() - 1]; + ax.setSize(ctxt.preset(), num_polyunit); if (!ax[0].isNTT()) { - forwardNTT(modarith, ax); + forwardNTT(modarith, ax, 0, ntt_type); } for (Size i = 0; i < num_secret; ++i) { CiphertextT ctxt_tmp(ctxt, i); ctxt_tmp.setNumPolyunit(num_polyunit); for (Size j = 0; j < ctxt_tmp.numPoly(); ++j) { if (!ctxt_tmp[j][0].isNTT()) { - forwardNTT(modarith, ctxt_tmp[j]); + forwardNTT(modarith, ctxt_tmp[j], 0, ntt_type); } } - if constexpr (std::is_same_v || - std::is_same_v) { - PolynomialT ptxt_tmp = innerDecrypt(ctxt_tmp, sk[i], ax); - decode(ptxt_tmp, msg[i], scale); - } else if constexpr (std::is_same_v || - std::is_same_v) { - PolynomialT ptxt_tmp = innerDecrypt(ctxt_tmp, sk[i], ax); - innerDecode(ptxt_tmp, msg[i], scale); - } else { - throw std::runtime_error( - "[Decryptor::decrypt] Unsupported message type"); - } + PolynomialT ptxt_tmp = innerDecrypt(ctxt_tmp, sk[i], ax); + decode(ptxt_tmp, msg[i], scale, is_real); } utils::unsetOmpThreadLimit(); } @@ -127,7 +120,7 @@ DecryptorT::innerDecrypt(const CiphertextT &ctxt, const std::optional> &ax) const { PolynomialT ptxt(preset, std::min(ctxt[0].size(), MAX_DECRYPT_SIZE)); for (u64 i = 0; i < ptxt.size(); ++i) { - ptxt[i].setNTT(ctxt[0][i].isNTT()); + ptxt[i].setNTT(ctxt[0][i].getNTTType(), ctxt[0][i].getNTTRootType()); } // m = c_0 + (c_1 + ... + (c_{n-1} + c_n * s) * s ... ) * s u64 last_idx = ctxt.numPoly() - 1; @@ -160,28 +153,87 @@ void DecryptorT::innerDecode(const PolynomialT &ptxt, CMSG &coeff, template template -void DecryptorT::decode(const PolynomialT &ptxt, MSG &msg, - Real scale) const { +void DecryptorT::decode(const PolynomialT &ptxt, MSG &msg, Real scale, + bool is_real) const { deb_assert(msg.size() >= num_slots, "[Decryptor::decode] Message size is too small"); - if constexpr (std::is_same_v || - std::is_same_v) { + if (is_real) { + deb_assert( + msg.size() == degree, + "[Decryptor::decode] For real encoding, message size must be " + "equal to polynomial degree"); CoeffMessage coeff(preset); innerDecode(ptxt, coeff, scale); - - const auto half_degree = num_slots; - for (Size i = 0; i < msg.size(); ++i) { - msg[i].real(coeff[i]); - msg[i].imag(coeff[i + half_degree]); + if constexpr (std::is_same_v) { + msg[0] = {coeff[0], 0}; + for (Size i = 1; i < degree; ++i) { + msg[i].real(coeff[i]); + msg[i].imag(-coeff[degree - i]); + } + fft_.forwardFFT(msg); + } else if constexpr (std::is_same_v) { + msg[0] = {static_cast(coeff[0]), 0}; + for (Size i = 1; i < degree; ++i) { + msg[i].real(static_cast(coeff[i])); + msg[i].imag(static_cast(-coeff[degree - i])); + } + fft_.forwardFFT(msg); + } else if constexpr (std::is_same_v || + std::is_same_v) { + Message msg_tmp(degree); + msg_tmp[0] = {coeff[0], 0}; + for (Size i = 1; i < degree; ++i) { + msg_tmp[i].real(coeff[i]); + msg_tmp[i].imag(-coeff[degree - i]); + } + fft_.forwardFFT(msg_tmp); + for (Size i = 0; i < msg.size(); ++i) { + if constexpr (std::is_same_v) { + msg[i] = msg_tmp[i].real(); + } else if constexpr (std::is_same_v) { + msg[i] = static_cast(msg_tmp[i].real()); + } + } + } else { + throw std::runtime_error( + "[Decryptor::decode] Unsupported message type"); } - fft_.forwardFFT(msg); } else { - throw std::runtime_error( - "[Decryptor::decode] Unsupported message type"); + if constexpr (std::is_same_v || + std::is_same_v) { + CoeffMessage coeff(preset); + innerDecode(ptxt, coeff, scale); + const auto half_degree = num_slots; + using ScalarT = + typename std::remove_reference_t::value_type; + for (Size i = 0; i < msg.size(); ++i) { + msg[i].real(static_cast(coeff[i])); + msg[i].imag(static_cast(coeff[i + half_degree])); + } + fft_.forwardFFT(msg); + } else if constexpr (std::is_same_v || + std::is_same_v) { + innerDecode(ptxt, msg, scale); + } else { + throw std::runtime_error( + "[Decryptor::decode] Unsupported message type"); + } + } +} + +template +void DecryptorT::changeNTTRootType(utils::NTTRootType root_type) { + for (Size i = 0; i < num_p; ++i) { + modarith[i].setNTTRootType(root_type); } } +template +utils::NTTRootType DecryptorT::getNTTRootType() const { + return modarith[0].getNTTRootType(); +} + template template void DecryptorT::decodeWithSinglePoly(const PolynomialT &ptxt, @@ -198,7 +250,7 @@ void DecryptorT::decodeWithSinglePoly(const PolynomialT &ptxt, U *interim = ptxt[0].data(); if (ptxt[0].isNTT()) { - modarith[0].backwardNTT(interim); + modarith[0].backwardNTT(interim, ptxt[0].getNTTType()); } Real tmp; @@ -227,19 +279,19 @@ void DecryptorT::decodeWithPolyPair(const PolynomialT &ptxt, deb_assert(coeff.size() >= ptxt_degree, "[Decryptor::decodeWithPolyPair] Coeff size is too small"); - const auto prime0 = primes[0]; - const auto prime1 = primes[1]; + const U prime0 = static_cast(primes[0]); + const U prime1 = static_cast(primes[1]); const utils::u128 prod_prime = utils::mul64To128(prime0, prime1); const utils::u128 half_prod_prime = prod_prime >> 1; - const u64 bezout0 = modarith[1].inverse(prime0); - const u64 bezout1 = modarith[0].inverse(prime1); + const U bezout0 = modarith[1].inverse(prime0); + const U bezout1 = modarith[0].inverse(prime1); U *ptxt0 = ptxt[0].data(); U *ptxt1 = ptxt[1].data(); if (ptxt[0].isNTT()) { - modarith[0].backwardNTT(ptxt0); - modarith[1].backwardNTT(ptxt1); + modarith[0].backwardNTT(ptxt0, ptxt[0].getNTTType()); + modarith[1].backwardNTT(ptxt1, ptxt[1].getNTTType()); } modarith[0].constMultInPlace(ptxt0, bezout1); modarith[1].constMultInPlace(ptxt1, bezout0); diff --git a/src/Encryptor.cpp b/src/Encryptor.cpp index de3cc7c..16f77c7 100644 --- a/src/Encryptor.cpp +++ b/src/Encryptor.cpp @@ -32,53 +32,40 @@ namespace deb { template EncryptorT::EncryptorT(std::optional seeds) - : PresetTraits(preset), ptxt_buffer_(preset, num_p * num_secret), - vx_buffer_(preset, true), ex_buffer_(preset, true), - samples_(buffer_size(degree)), mask_(degree), i_samples_(degree), - fft_(degree) { + : EncryptorT(P, seeds) { if constexpr (P == PRESET_EMPTY) { throw std::runtime_error( "[Encryptor] Preset template must be specified when using this " "constructor"); } - - for (Size i = 0; i < num_p; ++i) { - modarith.emplace_back(primes[i]); - } - - if (!seeds) { - seeds.emplace(SeedGenerator::Gen()); - } - rng_ = createRandomGenerator(seeds.value()); } template EncryptorT::EncryptorT(Preset target_preset, std::optional seeds) : PresetTraits(target_preset), + rng_(createRandomGenerator(seeds.value_or(SeedGenerator::Gen()))), ptxt_buffer_(target_preset, num_p * num_secret), vx_buffer_(target_preset, true), ex_buffer_(target_preset, true), - samples_(buffer_size(degree)), mask_(degree), i_samples_(degree), - fft_(degree) { - + mask_(degree), samples_(buffer_size(degree)), i_samples_(degree), + fft_(degree << 1) { + if (target_preset == PRESET_EMPTY) { + throw std::runtime_error("[Encryptor] Target preset must be specified " + "when PRESET_EMPTY template is used"); + } for (Size i = 0; i < num_p; ++i) { modarith.emplace_back(degree, primes[i]); } - - if (!seeds) { - seeds.emplace(SeedGenerator::Gen()); - } - rng_ = createRandomGenerator(seeds.value()); } template EncryptorT::EncryptorT(Preset target_preset, std::shared_ptr rng) - : PresetTraits(target_preset), + : PresetTraits(target_preset), rng_(std::move(rng)), ptxt_buffer_(target_preset, num_p * num_secret), vx_buffer_(target_preset, true), ex_buffer_(target_preset, true), - samples_(degree + (sizeof(u64) / sizeof(U)) * div_ceil_32(degree)), - mask_(degree), i_samples_(degree), rng_(std::move(rng)), fft_(degree) { + mask_(degree), samples_(buffer_size(degree)), i_samples_(degree), + fft_(degree << 1) { for (Size i = 0; i < num_p; ++i) { modarith.emplace_back(degree, primes[i]); @@ -137,12 +124,11 @@ void EncryptorT::encrypt(const MSG *msg, const KEY &key, for (Size i = 0; i < num_secret; ++i) { PolynomialT ptxt_tmp(ptxt, single_num_polyunit * i, single_num_polyunit); - encode(msg[i], ptxt_tmp, single_num_polyunit, opt.scale); + encode(msg[i], ptxt_tmp, single_num_polyunit, opt); } } else { - encode(msg[0], ptxt, single_num_polyunit, opt.scale); + encode(msg[0], ptxt, single_num_polyunit, opt); } - innerEncrypt(ptxt, key, single_num_polyunit, ctxt); if constexpr (std::is_same_v || std::is_same_v) { @@ -154,10 +140,15 @@ void EncryptorT::encrypt(const MSG *msg, const KEY &key, throw std::runtime_error( "[Encryptor::encrypt] Unsupported message type"); } + if (opt.real_encrypt) { + ctxt.setEncoding(REAL); + } + innerEncrypt(ptxt, key, single_num_polyunit, ctxt); if (!opt.ntt_out) { for (u64 i = 0; i < ctxt.numPoly(); ++i) { - backwardNTT(modarith, ctxt[i]); + backwardNTT(modarith, ctxt[i], single_num_polyunit, + ctxt[i][0].getNTTType()); } } utils::unsetOmpThreadLimit(); @@ -175,8 +166,11 @@ void EncryptorT::innerEncrypt(const PolynomialT &ptxt, const KEY &key, "[Encryptor::innerEncrypt] Rank must be 1 or NumSecret must be " "1"); bool isNTT = ptxt[0].isNTT(); + bool cyclicNTT = ctxt.encoding() == REAL; + utils::NTTType ntt_type = + cyclicNTT ? utils::NTTType::CYCLIC : utils::NTTType::NEGACYCLIC; ctxt.setNumPolyunit(num_polyunit); - ctxt.setNTT(true); + ctxt.setNTT(ntt_type, modarith[0].getNTT(ntt_type)->getRootType()); if constexpr (std::is_same_v>) { deb_assert(key.numPoly() == num_secret * rank, @@ -206,13 +200,15 @@ void EncryptorT::innerEncrypt(const PolynomialT &ptxt, const KEY &key, PRAGMA_OMP(omp parallel) { for (Size i = 0; i < num_secret; ++i) { - sampleGaussian(num_polyunit, isNTT); + sampleGaussian(num_polyunit, + isNTT ? ntt_type : utils::NTTType::NONNTT); // e = e + m addPoly(modarith, ex_buffer_, ptxt_vec[i], ex_buffer_, num_polyunit); // perform delayed NTT if (!isNTT) { - forwardNTT(modarith, ex_buffer_, num_polyunit); + forwardNTT(modarith, ex_buffer_, num_polyunit, + ntt_type); } mulPolyConst(modarith, ctxt[num_secret], key[i], ctxt[i]); subPoly(modarith, ex_buffer_, ctxt[i], ctxt[i]); @@ -223,14 +219,15 @@ void EncryptorT::innerEncrypt(const PolynomialT &ptxt, const KEY &key, PolynomialT tmp(preset, num_polyunit); PRAGMA_OMP(omp parallel) { - sampleGaussian(num_polyunit, isNTT); + sampleGaussian(num_polyunit, + isNTT ? ntt_type : utils::NTTType::NONNTT); // e = e + m addPoly(modarith, ex_buffer_, ptxt, ex_buffer_, num_polyunit); // perform delayed NTT if (!isNTT) { - forwardNTT(modarith, ex_buffer_, num_polyunit); + forwardNTT(modarith, ex_buffer_, num_polyunit, ntt_type); } // TODO: not tested yet since no preset of rank > 1 // b = - \sigma a_i * s_i + e + m @@ -251,28 +248,56 @@ void EncryptorT::innerEncrypt(const PolynomialT &ptxt, const KEY &key, } PRAGMA_OMP(omp parallel) { - sampleZO(num_polyunit); - sampleGaussian(num_polyunit, true); + sampleZO(num_polyunit, ntt_type); + sampleGaussian(num_polyunit, ntt_type); mulPolyConst(modarith, vx_buffer_, key.ax(0), ctxt[num_secret], num_polyunit); addPoly(modarith, ctxt[num_secret], ex_buffer_, ctxt[num_secret]); for (Size i = 0; i < num_secret; ++i) { - sampleGaussian(num_polyunit, isNTT); + sampleGaussian(num_polyunit, + isNTT ? ntt_type : utils::NTTType::NONNTT); mulPoly(modarith, vx_buffer_, key.bx(i), ctxt[i], num_polyunit); addPoly(modarith, ex_buffer_, ptxt_vec[i], ex_buffer_, num_polyunit); if (!isNTT) { - forwardNTT(modarith, ex_buffer_, num_polyunit); + forwardNTT(modarith, ex_buffer_, num_polyunit, + ntt_type); } addPoly(modarith, ctxt[i], ex_buffer_, ctxt[i]); } } } else { - // not implemented yet + // ctxt[0] = v * bx(0) + e_0 + m (b component) + // ctxt[1..r-1] = 0 (unused, zeroed) + // ctxt[rank] = v * ax(0) + e_rank (a component) + for (Size k = 1; k < rank; ++k) { + ctxt[k].setNTT(ntt_type, + modarith[0].getNTT(ntt_type)->getRootType()); + for (Size i = 0; i < num_polyunit; ++i) { + std::fill_n(ctxt[k][i].data(), degree, U(0)); + } + } + + PRAGMA_OMP(omp parallel) { + sampleZO(num_polyunit, ntt_type); + sampleGaussian(num_polyunit, ntt_type); + mulPolyConst(modarith, vx_buffer_, key.ax(0), ctxt[rank], + num_polyunit); + addPoly(modarith, ctxt[rank], ex_buffer_, ctxt[rank]); + + sampleGaussian(num_polyunit, + isNTT ? ntt_type : utils::NTTType::NONNTT); + mulPoly(modarith, vx_buffer_, key.bx(0), ctxt[0], num_polyunit); + addPoly(modarith, ex_buffer_, ptxt, ex_buffer_, num_polyunit); + if (!isNTT) { + forwardNTT(modarith, ex_buffer_, num_polyunit, ntt_type); + } + addPoly(modarith, ctxt[0], ex_buffer_, ctxt[0]); + } } } else { throw std::runtime_error( @@ -294,7 +319,7 @@ void EncryptorT::innerEncode(const MSG &msg, const Real &delta, ((std::is_same_v) ? 2 : 1)); for (Size i = 0; i < size; i++) { - ptxt[i].setNTT(false); + ptxt[i].setNTT(utils::NTTType::NONNTT); if (degree > msg_size * ((std::is_same_v) ? 2 : 1)) std::fill_n(ptxt[i].data(), degree, U(0)); } @@ -339,40 +364,83 @@ void EncryptorT::innerEncode(const MSG &msg, const Real &delta, template template void EncryptorT::encode(const MSG &msg, PolynomialT &ptxt, - const Size size, const Real scale) const { - const Real delta{scale == 0 ? std::pow(static_cast(2), - scale_factors[ptxt.size() - 1]) - : scale}; - if constexpr (std::is_same_v || - std::is_same_v) { - innerEncode(msg, delta, ptxt, size); - } else if constexpr (std::is_same_v) { - Message tmp(msg.size(), msg.data()); - fft_.backwardFFT(tmp); - innerEncode(tmp, delta, ptxt, size); - } else if constexpr (std::is_same_v) { - Message tmp(msg.size()); - for (Size i = 0; i < msg.size(); ++i) { - tmp[i] = ComplexT(static_cast(msg[i].real()), - static_cast(msg[i].imag())); + const Size size, + const EncryptOptions &opt) const { + const Real delta{opt.scale == 0 ? std::pow(static_cast(2), + scale_factors[ptxt.size() - 1]) + : opt.scale}; + + if (opt.real_encrypt) { + deb_assert(msg.size() == degree, + "[Encryptor::encode] For real encryption, message size must " + "match polynomial degree"); + Message tmp(degree); + if constexpr (std::is_same_v || + std::is_same_v) { + for (Size i = 0; i < tmp.size(); ++i) { + tmp[i] = ComplexT(msg[i], 0); + } + } else if constexpr (std::is_same_v) { + std::copy_n(msg.data(), msg.size(), tmp.data()); + } else if constexpr (std::is_same_v) { + for (Size i = 0; i < tmp.size(); ++i) { + tmp[i] = ComplexT(static_cast(msg[i].real()), 0); + } + } else { + throw std::runtime_error( + "[Encryptor::encode] Unsupported message type"); } fft_.backwardFFT(tmp); - innerEncode(tmp, delta, ptxt, size); + CoeffMessage tmp_coeff(degree); + for (Size i = 0; i < tmp.size(); ++i) { + tmp_coeff[i] = static_cast(tmp[i].real()); + } + innerEncode(tmp_coeff, delta, ptxt, size); } else { - throw std::runtime_error( - "[Encryptor::encode] Unsupported message type"); + if constexpr (std::is_same_v || + std::is_same_v) { + innerEncode(msg, delta, ptxt, size); + } else if constexpr (std::is_same_v) { + Message tmp(msg.size(), msg.data()); + fft_.backwardFFT(tmp); + innerEncode(tmp, delta, ptxt, size); + } else if constexpr (std::is_same_v) { + Message tmp(msg.size()); + for (Size i = 0; i < msg.size(); ++i) { + tmp[i] = ComplexT(static_cast(msg[i].real()), + static_cast(msg[i].imag())); + } + fft_.backwardFFT(tmp); + innerEncode(tmp, delta, ptxt, size); + } else { + throw std::runtime_error( + "[Encryptor::encode] Unsupported message type"); + } } } template -void EncryptorT::sampleZO(Size num_polyunit) const { +void EncryptorT::changeNTTRootType(utils::NTTRootType root_type) { + for (Size i = 0; i < num_p; ++i) { + modarith[i].setNTTRootType(root_type); + } +} + +template +utils::NTTRootType EncryptorT::getNTTRootType() const { + return modarith[0].getNTTRootType(); +} + +template +void EncryptorT::sampleZO(Size num_polyunit, + const utils::NTTType ntt_type) const { // We sample 64 bits at a time and use 2 bits for each coefficient to sample // Assume degree is larger than 32. const auto sample_size = div_ceil_32(degree); PRAGMA_OMP(omp single) { - vx_buffer_.setNTT(false); + vx_buffer_.setNTT(utils::NTTType::NONNTT); // Since this method is in a OMP parallel region, we cannot make local // array So sample data into the end of class variable samples_ rng_->getRandomUint64Array( @@ -394,23 +462,25 @@ void EncryptorT::sampleZO(Size num_polyunit) const { for (Size i = 0; i < num_polyunit; ++i) { for (Size j = 0; j < degree; ++j) { const U mask = mask_[j]; - const U bit = samples_[j]; + const U bit = static_cast(samples_[j]); vx_buffer_[i][j] = (bit & mask) | (static_cast(primes[i] - bit) & ~mask); } } - forwardNTT(modarith, vx_buffer_, num_polyunit); + if (ntt_type != utils::NTTType::NONNTT) { + forwardNTT(modarith, vx_buffer_, num_polyunit, ntt_type); + } } template void EncryptorT::sampleGaussian(const Size num_polyunit, - const bool do_ntt) const { - + const utils::NTTType ntt_type) const { + // Initialize the buffer with Gaussian samples in standard representation PRAGMA_OMP(omp single) { rng_->sampleGaussianInt64Array(i_samples_.data(), degree, gaussian_error_stdev); - ex_buffer_.setNTT(false); + ex_buffer_.setNTT(utils::NTTType::NONNTT); } PRAGMA_OMP(omp for collapse(2) schedule(static)) @@ -428,8 +498,8 @@ void EncryptorT::sampleGaussian(const Size num_polyunit, } } - if (do_ntt) { - forwardNTT(modarith, ex_buffer_, num_polyunit); + if (ntt_type != utils::NTTType::NONNTT) { + forwardNTT(modarith, ex_buffer_, num_polyunit, ntt_type); } } diff --git a/src/KeyGenerator.cpp b/src/KeyGenerator.cpp index eadf805..d1da266 100644 --- a/src/KeyGenerator.cpp +++ b/src/KeyGenerator.cpp @@ -34,9 +34,9 @@ inline void checkSecretKey([[maybe_unused]] const deb::Preset preset, }; template -inline void checkSwk([[maybe_unused]] const deb::Preset &preset, +inline void checkSwk([[maybe_unused]] const deb::Preset preset, [[maybe_unused]] const deb::SwitchKeyT &swk, - const deb::SwitchKeyKind expected_type) { + [[maybe_unused]] const deb::SwitchKeyKind expected_type) { deb_assert(preset == swk.preset(), "[KeyGenerator] Preset mismatch between KeyGenerator and " "SwitchingKey."); @@ -45,9 +45,9 @@ inline void checkSwk([[maybe_unused]] const deb::Preset &preset, }; inline void -checkModPackKeyBundleCondition([[maybe_unused]] const deb::Preset &preset, - [[maybe_unused]] const deb::Preset &preset_from, - [[maybe_unused]] const deb::Preset &preset_to) { +checkModPackKeyBundleCondition([[maybe_unused]] const deb::Preset preset, + [[maybe_unused]] const deb::Preset preset_from, + [[maybe_unused]] const deb::Preset preset_to) { [[maybe_unused]] const deb::Size from_degree = get_degree(preset_from); [[maybe_unused]] const deb::Size from_rank = get_rank(preset_from); @@ -93,23 +93,20 @@ template KeyGeneratorT::KeyGeneratorT(std::optional seeds) : KeyGeneratorT(P, std::move(seeds)) { if constexpr (P == PRESET_EMPTY) { - throw std::runtime_error( - "[KeyGenerator] Preset must be specified for EMPTY preset."); + throw std::runtime_error("[KeyGenerator] Preset must be specified when " + "PRESET_EMPTY template is used."); } } template KeyGeneratorT::KeyGeneratorT(const Preset target_preset, std::optional seeds) - : PresetTraits(target_preset), fft_(degree) { + : PresetTraits(target_preset), + rng_(createRandomGenerator(seeds.value_or(SeedGenerator::Gen()))), + fft_(degree) { for (u64 i = 0; i < num_p; ++i) { modarith.emplace_back(degree, primes[i]); } - if (!seeds) { - seeds.emplace(SeedGenerator::Gen()); - } - rng_ = createRandomGenerator(seeds.value()); - computeConst(); } @@ -126,7 +123,8 @@ KeyGeneratorT::KeyGeneratorT(const Preset target_preset, template void KeyGeneratorT::genSwitchingKey( const PolynomialT *from, const PolynomialT *to, PolynomialT *ax, - PolynomialT *bx, const Size ax_size, const Size bx_size) const { + PolynomialT *bx, const Size ax_size, const Size bx_size, + const utils::NTTType ntt_type) const { const Size length = num_base + num_qp; const Size max_length = num_p; const Size dnum = gadget_rank; @@ -145,7 +143,8 @@ void KeyGeneratorT::genSwitchingKey( const auto &a = ax[idx]; for (Size sid = 0; sid < s_size; ++sid) { auto &b = bx[idx + sid * a_size]; - auto ex = sampleGaussian(max_length, true); + auto ex = sampleGaussian(max_length); + forwardNTT(modarith, ex, 0, ntt_type); mulPoly(modarith, a, to[sid], b); subPoly(modarith, ex, b, b); @@ -186,14 +185,14 @@ void KeyGeneratorT::genEncKeyInplace(SwitchKeyT &enckey, const SecretKeyT &sk) const { checkSecretKey(preset, sk); checkSwk(preset, enckey, SWK_ENC); - const bool ntt_state = true; // currently only support ntt state keys const Size num_poly = num_p; deb_assert(enckey.bxSize() == num_secret && enckey.axSize() == 1, "[KeyGenerator::genEncKeyInplace] " "The provided switching key has invalid size."); sampleUniform(enckey.ax()); - auto ex = sampleGaussian(num_poly, ntt_state); + auto ex = sampleGaussian(num_poly); + forwardNTT(modarith, ex, 0, sk[0][0].getNTTType()); for (Size i = 0; i < num_secret; ++i) { mulPoly(modarith, enckey.ax(), sk[i], enckey.bx(i)); @@ -213,7 +212,6 @@ void KeyGeneratorT::genMultKeyInplace(SwitchKeyT &mulkey, const SecretKeyT &sk) const { checkSecretKey(preset, sk); checkSwk(preset, mulkey, SWK_MULT); - const bool ntt_state = true; // currently only support ntt state keys const Size max_length = num_p; deb_assert(mulkey.bxSize() == num_secret * mulkey.dnum() && mulkey.axSize() == mulkey.dnum(), @@ -223,12 +221,11 @@ void KeyGeneratorT::genMultKeyInplace(SwitchKeyT &mulkey, std::vector> sx2; for (Size i = 0; i < num_secret; ++i) { sx2.emplace_back(preset, max_length); - sx2[i].setNTT(ntt_state); - + sx2[i].setNTT(sk[0][0].getNTTType(), sk[0][0].getNTTRootType()); mulPoly(modarith, sk[i], sk[i], sx2[i]); } genSwitchingKey(sx2.data(), sk.data(), mulkey.getAx().data(), - mulkey.getBx().data()); + mulkey.getBx().data(), 0, 0, sk[0][0].getNTTType()); for (Size i = 0; i < sx2.size(); ++i) { for (Size j = 0; j < sx2[i].size(); ++j) { deb_secure_zero(sx2[i][j].data(), sx2[i][j].degree() * sizeof(U)); @@ -248,7 +245,6 @@ void KeyGeneratorT::genConjKeyInplace(SwitchKeyT &conjkey, const SecretKeyT &sk) const { checkSecretKey(preset, sk); checkSwk(preset, conjkey, SWK_CONJ); - const bool ntt_state = sk[0][0].isNTT(); const Size max_length = num_p; deb_assert(conjkey.bxSize() == num_secret * conjkey.dnum() && @@ -259,13 +255,13 @@ void KeyGeneratorT::genConjKeyInplace(SwitchKeyT &conjkey, std::vector> sx; for (Size i = 0; i < num_secret; ++i) { sx.emplace_back(preset, max_length); - sx[i].setNTT(ntt_state); + sx[i].setNTT(sk[0][0].getNTTType(), sk[0][0].getNTTRootType()); // frobenius map in NTT frobeniusMapInNTT(sk[i], -1, sx[i]); } genSwitchingKey(sx.data(), sk.data(), conjkey.getAx().data(), - conjkey.getBx().data()); + conjkey.getBx().data(), 0, 0, sk[0][0].getNTTType()); for (Size i = 0; i < sx.size(); ++i) { for (Size j = 0; j < sx[i].size(); ++j) { deb_secure_zero(sx[i][j].data(), sx[i][j].degree() * sizeof(U)); @@ -290,7 +286,6 @@ void KeyGeneratorT::genLeftRotKeyInplace(const Size rot, checkSwk(preset, rotkey, SWK_ROT); deb_assert(rot < num_slots, "[KeyGenerator::genLeftRotKeyInplace] " "Rotation value exceeds number of slots."); - const auto ntt_state = true; // currently only support ntt state keys const Size max_length = num_p; deb_assert(rotkey.bxSize() == num_secret * rotkey.dnum() && @@ -303,13 +298,12 @@ void KeyGeneratorT::genLeftRotKeyInplace(const Size rot, std::vector> sx; for (Size i = 0; i < num_secret; ++i) { sx.emplace_back(preset, max_length); - sx[i].setNTT(ntt_state); - + sx[i].setNTT(sk[0][0].getNTTType(), sk[0][0].getNTTRootType()); frobeniusMapInNTT(sk[i], static_cast(fft_.getPowerOfFive(rot)), sx[i]); } genSwitchingKey(sx.data(), sk.data(), rotkey.getAx().data(), - rotkey.getBx().data()); + rotkey.getBx().data(), 0, 0, sk[0][0].getNTTType()); for (Size i = 0; i < sx.size(); ++i) { for (Size j = 0; j < sx[i].size(); ++j) { deb_secure_zero(sx[i][j].data(), sx[i][j].degree() * sizeof(U)); @@ -362,10 +356,10 @@ void KeyGeneratorT::genAutoKeyInplace(const Size sig, automorphism(sk.coeffs() + i * degree, coeff_sig.data() + i * degree, sig, degree); } - auto sk_sig = - SecretKeyGeneratorT::GenSecretKeyFromCoeff(preset, coeff_sig.data()); + auto sk_sig = SecretKeyGeneratorT::GenSecretKeyFromCoeff( + preset, coeff_sig.data(), sk[0][0].getNTTType()); genSwitchingKey(sk_sig.data(), sk.data(), autokey.getAx().data(), - autokey.getBx().data()); + autokey.getBx().data(), 0, 0, sk[0][0].getNTTType()); deb_secure_zero(coeff_sig.data(), coeff_sig.size() * sizeof(i8)); // sk_sig.zeroize(); // automatically zeroized when going out of scope } @@ -431,10 +425,10 @@ void KeyGeneratorT::genComposeKeyInplace(const i8 *coeffs, coeffs_embed[i * deg_ratio] = coeffs[i]; } auto sk_from = SecretKeyGeneratorT::GenSecretKeyFromCoeff( - preset, coeffs_embed.data()); + preset, coeffs_embed.data(), sk[0][0].getNTTType()); genSwitchingKey(sk_from.data(), sk.data(), composekey.getAx().data(), - composekey.getBx().data()); + composekey.getBx().data(), 0, 0, sk[0][0].getNTTType()); // sk_from.zeroize(); // automatically zeroized when going out of scope } @@ -496,9 +490,9 @@ void KeyGeneratorT::genDecomposeKeyInplace( coeffs_embed[i * deg_ratio] = coeffs[i]; } auto sk_to = SecretKeyGeneratorT::GenSecretKeyFromCoeff( - preset, coeffs_embed.data()); + preset, coeffs_embed.data(), sk[0][0].getNTTType()); genSwitchingKey(sk.data(), sk_to.data(), decompkey.getAx().data(), - decompkey.getBx().data()); + decompkey.getBx().data(), 0, 0, sk[0][0].getNTTType()); // sk_to.zeroize(); // automatically zeroized when going out of scope } @@ -551,31 +545,32 @@ void KeyGeneratorT::genDecomposeKeyInplace( "Degree mismatch between KeyGenerator and switching key " "preset."); - const Size num_secret = get_num_secret(preset_swk); const Size deg_ratio = get_degree(preset_swk) / coeffs_size; deb_assert(coeffs_size * deg_ratio == degree, "[KeyGenerator::genDecomposeKey] " "The provided secret key has invalid size."); - deb_assert(num_secret == 1, "[KeyGenerator::genDecomposeKey] " - "Decomposition key generation is only " - "supported for single-secret presets."); + deb_assert(get_num_secret(preset_swk) == 1, + "[KeyGenerator::genDecomposeKey] " + "Decomposition key generation is only " + "supported for single-secret presets."); deb_assert(decompkey.bxSize() == decompkey.dnum() && decompkey.axSize() == decompkey.dnum(), "[KeyGenerator::genDecomposeKeyInplace] " "The provided switching key has invalid size."); + const auto ntt_type = sk[0][0].getNTTType(); std::vector coeffs_embed(degree, 0); for (Size i = 0; i < coeffs_size; ++i) { coeffs_embed[i * deg_ratio] = coeffs[i]; } auto sk_to = SecretKeyGeneratorT::GenSecretKeyFromCoeff( - preset_swk, coeffs_embed.data()); - auto sk_from = - SecretKeyGeneratorT::GenSecretKeyFromCoeff(preset_swk, sk.coeffs()); + preset_swk, coeffs_embed.data(), ntt_type); + auto sk_from = SecretKeyGeneratorT::GenSecretKeyFromCoeff( + preset_swk, sk.coeffs(), ntt_type); KeyGeneratorT keygen_swk(preset_swk); keygen_swk.genSwitchingKey(sk_from.data(), sk_to.data(), decompkey.getAx().data(), - decompkey.getBx().data()); + decompkey.getBx().data(), 0, 0, ntt_type); // sk_to.zeroize(); // automatically zeroized when going out of scope // sk_from.zeroize(); // automatically zeroized when going out of scope } @@ -598,7 +593,8 @@ template void KeyGeneratorT::genModPackKeyBundleInplace( const SecretKeyT &sk_from, const SecretKeyT &sk_to, std::vector> &key_bundle) const { - deb_assert(sk_from[0][0].isNTT() == sk_to[0][0].isNTT(), + deb_assert(sk_from[0][0].isNTT() == sk_to[0][0].isNTT() && + sk_from[0][0].getNTTType() == sk_to[0][0].getNTTType(), "[KeyGenerator::genModPackKeyBundle] " "NTT state mismatch between input secret keys."); deb_assert( @@ -630,8 +626,9 @@ void KeyGeneratorT::genModPackKeyBundleInplace( for (u64 k = 0; k < to_deg; ++k) rlwe_coeff[j + to_rank * k] = sk_to_coeff[k + to_deg * j]; - auto sk_to_rlwe = - SecretKeyGeneratorT::GenSecretKeyFromCoeff(preset, rlwe_coeff); + const auto ntt_type = sk_to[0][0].getNTTType(); + auto sk_to_rlwe = SecretKeyGeneratorT::GenSecretKeyFromCoeff( + preset, rlwe_coeff, ntt_type); for (u64 i = 0; i < num_keys; ++i) { // from_deg * (from_rank / num_keys) -> rlwe_deg ; embed and combine @@ -649,11 +646,11 @@ void KeyGeneratorT::genModPackKeyBundleInplace( for (u64 k = 0; k < from_deg; ++k) rlwe_coeff[j + deg_ratio * k] = sk_from_coeff[k + from_deg * (j + to_rank * i)]; - auto sk_from_rlwe = - SecretKeyGeneratorT::GenSecretKeyFromCoeff(preset, rlwe_coeff); + auto sk_from_rlwe = SecretKeyGeneratorT::GenSecretKeyFromCoeff( + preset, rlwe_coeff, ntt_type); genSwitchingKey(sk_from_rlwe.data(), sk_to_rlwe.data(), key_bundle[i].getAx().data(), - key_bundle[i].getBx().data()); + key_bundle[i].getBx().data(), 0, 0, ntt_type); // sk_from_rlwe.zeroize(); // automatically zeroized when going out of // scope } @@ -667,8 +664,10 @@ KeyGeneratorT::genModPackKeyBundle(const Size pad_rank, const SecretKeyT &sk) const { SwitchKeyT modkey(preset, SWK_MODPACK_SELF); const auto max_length = num_p; - modkey.addAx(max_length, pad_rank, true); - modkey.addBx(max_length, pad_rank * num_secret, true); + modkey.addAx(max_length, pad_rank, sk[0][0].getNTTType(), + sk[0][0].getNTTRootType()); + modkey.addBx(max_length, pad_rank * num_secret, sk[0][0].getNTTType(), + sk[0][0].getNTTRootType()); genModPackKeyBundleInplace(pad_rank, modkey, sk); return modkey; } @@ -686,6 +685,7 @@ void KeyGeneratorT::genModPackKeyBundleInplace( "[KeyGenerator::genModPackKeyBundle] The provided switching key " "has invalid size."); + const auto ntt_type = sk[0][0].getNTTType(); for (Size i = 0; i < pad_rank; ++i) { auto *from_coeff = new i8[degree]; std::memset(from_coeff, 0, degree); @@ -694,9 +694,9 @@ void KeyGeneratorT::genModPackKeyBundleInplace( sk.coeffs()[j * pad_rank + pad_rank - 1 - i]; } auto sk_from = SecretKeyGeneratorT::GenSecretKeyFromCoeff( - sk.preset(), from_coeff); + sk.preset(), from_coeff, ntt_type); genSwitchingKey(sk_from.data(), sk.data(), &(modkey.ax(i)), - &(modkey.bx(i)), 1, num_secret); + &(modkey.bx(i)), 1, num_secret, ntt_type); deb_secure_zero(from_coeff, degree * sizeof(i8)); delete[] from_coeff; // sk_from.zeroize(); // automatically zeroized when going out of scope @@ -712,8 +712,6 @@ void KeyGeneratorT::frobeniusMapInNTT(const PolynomialT &op, deb_assert(pow % 2 != 0, "[KeyGenerator::frobeniusMapInNTT] " "Frobenius map power must be odd."); - u64 log_degree = utils::log2floor(static_cast(degree)); - if (pow == 1) { res = op; } else if (pow == -1) { @@ -747,8 +745,8 @@ void KeyGeneratorT::frobeniusMapInNTT(const PolynomialT &op, } template -PolynomialT KeyGeneratorT::sampleGaussian(const Size num_polyunit, - bool do_ntt) const { +PolynomialT +KeyGeneratorT::sampleGaussian(const Size num_polyunit) const { std::vector samples(degree); rng_->sampleGaussianInt64Array(samples.data(), degree, gaussian_error_stdev); @@ -764,9 +762,6 @@ PolynomialT KeyGeneratorT::sampleGaussian(const Size num_polyunit, } } - if (do_ntt) { - forwardNTT(modarith, poly); - } return poly; } diff --git a/src/ModArith.cpp b/src/ModArith.cpp index 479a658..7428c5f 100644 --- a/src/ModArith.cpp +++ b/src/ModArith.cpp @@ -30,26 +30,32 @@ namespace deb::utils { // --------------------------------------------------------------------------- template -ModArith::ModArith(u64 prime) +ModArith::ModArith(u64 prime, bool default_ntt_is_cyclic) : prime_(static_cast(prime)), two_prime_(static_cast(prime << 1)), barrett_expt_(bitWidth(prime) - 1), barrett_ratio_(static_cast( (static_cast(1) << (barrett_expt_ + 63)) / prime)), - default_array_size_(degree), + default_array_size_(D), barrett_ratio_for_u64_(divide128By64Lo(UINT64_C(1), UINT64_C(0), prime)), barrett_ratio_for_u32_( static_cast((static_cast(1) << 32) / prime)), two_to_64_(powModSimple(2, 64, prime)), two_to_64_shoup_(divide128By64Lo(two_to_64_, UINT64_C(0), prime)), - ntt_(std::make_shared>(degree, prime)) { + root_type_(getGlobalNTTRootType()) { if constexpr (D == 1) { throw std::runtime_error("[ModArith] Degree template parameter must be " "non-zero when degree is not specified"); } + if (default_ntt_is_cyclic) { + ensureCyclicNTT(); + } else { + ensureNegacyclicNTT(); + } } template -ModArith::ModArith(Size actual_degree, u64 prime) +ModArith::ModArith(Size actual_degree, u64 prime, + bool default_ntt_is_cyclic) : DegreeTrait(actual_degree), prime_(static_cast(prime)), two_prime_(static_cast(prime << 1)), barrett_expt_(bitWidth(prime) - 1), @@ -61,7 +67,13 @@ ModArith::ModArith(Size actual_degree, u64 prime) static_cast((static_cast(1) << 32) / prime)), two_to_64_(powModSimple(2, 64, prime)), two_to_64_shoup_(divide128By64Lo(two_to_64_, UINT64_C(0), prime)), - ntt_(std::make_shared>(actual_degree, prime)) {} + root_type_(getGlobalNTTRootType()) { + if (default_ntt_is_cyclic) { + ensureCyclicNTT(); + } else { + ensureNegacyclicNTT(); + } +} // --------------------------------------------------------------------------- // constMult @@ -136,31 +148,43 @@ inline void for_each_modarith(const std::vector> &modarith, template void forwardNTT(const std::vector> &modarith, - PolynomialT &poly, Size num_polyunit, + PolynomialT &poly, Size num_polyunit, NTTType ntt_type, [[maybe_unused]] bool expected_ntt_state) { deb_assert(poly[0].isNTT() == expected_ntt_state, "[forwardNTT] NTT state mismatch"); + deb_assert(ntt_type != NTTType::NONNTT, + "[forwardNTT] Invalid NTT type: NONNTT"); num_polyunit = num_polyunit ? num_polyunit : poly.size(); for_each_modarith( - modarith, [](const ModArith &ma, U *p) { ma.forwardNTT(p); }, + modarith, + [ntt_type](const ModArith &ma, U *p) { + ma.forwardNTT(p, ntt_type); + }, num_polyunit, poly); for (Size i = 0; i < num_polyunit; ++i) { - poly[i].setNTT(true); + poly[i].setNTT(modarith[i].getNTT(ntt_type)->getType(), + modarith[i].getNTT(ntt_type)->getRootType()); } } template void backwardNTT(const std::vector> &modarith, - PolynomialT &poly, Size num_polyunit, + PolynomialT &poly, Size num_polyunit, NTTType ntt_type, [[maybe_unused]] bool expected_ntt_state) { deb_assert(poly[0].isNTT() == expected_ntt_state, "[backwardNTT] NTT state mismatch"); + deb_assert( + ntt_type == poly[0].getNTTType(), + "[backwardNTT] NTT type mismatch between ModArith and polynomial"); num_polyunit = num_polyunit ? num_polyunit : poly.size(); for_each_modarith( - modarith, [](const ModArith &ma, U *p) { ma.backwardNTT(p); }, + modarith, + [ntt_type](const ModArith &ma, U *p) { + ma.backwardNTT(p, ntt_type); + }, num_polyunit, poly); for (Size i = 0; i < num_polyunit; ++i) { - poly[i].setNTT(false); + poly[i].setNTT(utils::NTTType::NONNTT); } } @@ -174,7 +198,9 @@ void addPoly(const std::vector> &modarith, PolynomialT &res, Size num_polyunit) { deb_assert(op1[0].isNTT() == op2[0].isNTT(), "[addPoly] operands NTT state mismatch"); - PRAGMA_OMP(omp single) { res.setNTT(op1[0].isNTT()); } + PRAGMA_OMP(omp single) { + res.setNTT(op1[0].getNTTType(), op1[0].getNTTRootType()); + } const auto degree = res[0].degree(); num_polyunit = num_polyunit ? num_polyunit : res.size(); @@ -194,7 +220,9 @@ void addPolyConst(const std::vector> &modarith, PolynomialT &res, Size num_polyunit) { deb_assert(op1[0].isNTT() == op2[0].isNTT(), "[addPoly] operands NTT state mismatch"); - PRAGMA_OMP(omp single) { res.setNTT(op1[0].isNTT()); } + PRAGMA_OMP(omp single) { + res.setNTT(op1[0].getNTTType(), op1[0].getNTTRootType()); + } const auto degree = res[0].degree(); num_polyunit = num_polyunit ? num_polyunit : res.size(); @@ -214,7 +242,9 @@ void subPoly(const std::vector> &modarith, PolynomialT &res, Size num_polyunit) { deb_assert(op1[0].isNTT() == op2[0].isNTT(), "[subPoly] operands NTT state mismatch"); - PRAGMA_OMP(omp single) { res.setNTT(op1[0].isNTT()); } + PRAGMA_OMP(omp single) { + res.setNTT(op1[0].getNTTType(), op1[0].getNTTRootType()); + } const auto degree = res[0].degree(); num_polyunit = num_polyunit ? num_polyunit : res.size(); @@ -237,7 +267,9 @@ void mulPoly(const std::vector> &modarith, PolynomialT &res, Size num_polyunit) { deb_assert(op1[0].isNTT() == op2[0].isNTT(), "[mulPoly] operands NTT state mismatch"); - PRAGMA_OMP(omp single) { res.setNTT(op1[0].isNTT()); } + PRAGMA_OMP(omp single) { + res.setNTT(op1[0].getNTTType(), op1[0].getNTTRootType()); + } const auto degree = res[0].degree(); num_polyunit = num_polyunit ? num_polyunit : res.size(); @@ -275,7 +307,9 @@ void mulPolyConst(const std::vector> &modarith, PolynomialT &res, Size num_polyunit) { deb_assert(op1[0].isNTT() == op2[0].isNTT(), "[mulPoly] operands NTT state mismatch"); - PRAGMA_OMP(omp single) { res.setNTT(op1[0].isNTT()); } + PRAGMA_OMP(omp single) { + res.setNTT(op1[0].getNTTType(), op1[0].getNTTRootType()); + } const auto degree = res[0].degree(); num_polyunit = num_polyunit ? num_polyunit : res.size(); @@ -311,7 +345,9 @@ template void constMulPoly(const std::vector> &modarith, const PolynomialT &op1, const U *op2, PolynomialT &res, Size s_id, Size e_id) { - PRAGMA_OMP(omp single) { res.setNTT(op1[0].isNTT()); } + PRAGMA_OMP(omp single) { + res.setNTT(op1[0].getNTTType(), op1[0].getNTTRootType()); + } PRAGMA_OMP(omp for schedule(static)) for (Size i = s_id; i < e_id; ++i) { diff --git a/src/NTT.cpp b/src/NTT.cpp index 6803e9e..39fd48f 100644 --- a/src/NTT.cpp +++ b/src/NTT.cpp @@ -19,6 +19,10 @@ #include #include +#include +#include +#include +#include #ifdef DEB_OPENMP #include #endif @@ -68,146 +72,118 @@ u64 findPrimitiveRoot(u64 prime) { } // namespace utils -// --------------------------------------------------------------------------- -// Butterfly operations -- templated on U -// --------------------------------------------------------------------------- namespace { -template -static inline void butterfly(U &x, U &y, U w, U ws, U p1, U p2) { - const U ty = mulModLazy(y, w, ws, p1); - x = subIfGE(x, p2); - if constexpr (std::is_same_v) { - u64 x64 = static_cast(x); - u64 ty64 = static_cast(ty); - u64 p2_64 = static_cast(p2); - u64 sum = x64 + ty64; - u64 diff = x64 + p2_64 - ty64; - x = static_cast(subIfGE(sum, p2_64)); - y = static_cast(subIfGE(diff, p2_64)); - } else { - y = static_cast(x + p2 - ty); - x = static_cast(x + ty); - } -} - -template -static inline void butterflyInv(U &x, U &y, U w, U ws, U p1, U p2) { - if constexpr (std::is_same_v) { - u64 x64 = static_cast(subIfGE(x, p2)); - u64 y64 = static_cast(subIfGE(y, p2)); - u64 p2_64 = static_cast(p2); - U tx = static_cast(subIfGE(x64 + y64, p2_64)); - y = mulModLazy(static_cast(subIfGE(x64 + p2_64 - y64, p2_64)), - w, ws, p1); - x = tx; - } else { - const U tx = subIfGE(static_cast(x + y), p2); - y = mulModLazy(static_cast(x + p2 - y), w, ws, p1); - x = tx; - } -} - -} // anonymous namespace - // --------------------------------------------------------------------------- -// NTT constructor +// Root selection helpers (shared between NTT and NTT_C constructors) // --------------------------------------------------------------------------- -template -NTT::NTT(u64 degree, u64 prime) - : prime_(static_cast(prime)), two_prime_(static_cast(prime * 2)), - degree_(degree), psi_rev_(degree_), psi_inv_rev_(degree_), - psi_rev_shoup_(degree_), psi_inv_rev_shoup_(degree_) { +// Direct primitive-root search: finds a primitive (2*num_roots_log_param)-th +// root of unity modulo prime without the min-root selection step. +u64 findPrimitiveRootDirect(u64 prime, u64 degree) { + const u32 log_order = static_cast(log2floor(degree)) + 1u; - const u64 num_roots = degree_; + u64 odd_factor = prime - 1; + u32 max_two_exp = 0; + while ((odd_factor & u64(1)) == 0) { + odd_factor >>= 1; + max_two_exp++; + } - if (prime % (2 * num_roots) != 1) - throw std::runtime_error("[NTT] Not an NTT-friendly prime given."); + if (log_order > max_two_exp) + throw std::runtime_error("[NTT(DIRECT)] findPrimitiveRootDirect: " + "log_order > max_two_exp; " + "prime is not NTT-friendly for this degree."); - if (!isPowerOfTwo(degree_)) - throw std::runtime_error("[NTT] degree must be a power of two."); + const u64 last_pow = u64(1) << (max_two_exp - log_order); + const u64 max_base = std::min((prime - 1) / 2, u64(10000)); - // All construction arithmetic uses u64 for precision; narrowed to U when - // stored in the twiddle-factor vectors. - auto mult_with_barr = [](u64 x, u64 y, u64 y_barr, u64 prime_mod) { - u64 res = mulModLazy(x, y, y_barr, prime_mod); - return subIfGE(res, prime_mod); - }; + for (u64 base = 2; base <= max_base; base += (base == 2 ? 1 : 2)) { + if (base != 3 && (base % 3) == 0) + continue; - u64 psi = utils::findPrimitiveRoot(prime); - psi = powModSimple(psi, (prime - 1) / (2 * num_roots), prime); + u64 psi = powModSimple(base, odd_factor, prime); + if (psi == 1) + continue; - // Find the minimal 2N-th root of unity - u64 psi_square = mulModSimple(psi, psi, prime); - u64 psi_square_barr = divide128By64Lo(psi_square, 0, prime); - u64 min_root = psi; - u64 psi_tmp = psi; - for (u64 i = 0; i < num_roots; ++i) { - psi_tmp = mult_with_barr(psi_tmp, psi_square, psi_square_barr, prime); - if (psi_tmp < min_root) - min_root = psi_tmp; + u32 exp = 0; + u64 psi_pow = psi; + while (psi_pow != prime - 1 && exp < max_two_exp) { + psi_pow = mulModSimple(psi_pow, psi_pow, prime); + exp++; + } + + if (exp == max_two_exp - 1) + return powModSimple(psi, last_pow, prime); } - psi = min_root; - u64 psi_inv = invModSimple(psi, prime); - psi_rev_[0] = U(1); - psi_inv_rev_[0] = U(1); + throw std::runtime_error("[NTT(DIRECT)] findPrimitiveRootDirect: no " + "primitive root found within base search range."); +} - u64 idx = 0; - u64 previdx = 0; - u64 max_digits = log2floor(degree_); - u64 psi_barr = divide128By64Lo(psi, 0, prime); - u64 psi_inv_barr = divide128By64Lo(psi_inv, 0, prime); - for (u64 i = 1; i < degree_; i++) { - idx = bitReverse(static_cast(i), max_digits); - psi_rev_[idx] = static_cast(mult_with_barr( - static_cast(psi_rev_[previdx]), psi, psi_barr, prime)); - psi_inv_rev_[idx] = static_cast( - mult_with_barr(static_cast(psi_inv_rev_[previdx]), psi_inv, - psi_inv_barr, prime)); - previdx = idx; +// Iterates over the (num_roots) primitive root conjugates and returns the +// numerically smallest one. Shared between the negacyclic 2N-th-root and +// cyclic 4N-th-root searches. +u64 selectMinRoot(u64 prime, u64 num_roots, u64 root_seed) { + u64 root_sq = mulModSimple(root_seed, root_seed, prime); + u64 root_sq_barr = divide128By64Lo(root_sq, 0, prime); + u64 min_root = root_seed; + u64 cur = root_seed; + for (u64 i = 0; i < num_roots; ++i) { + cur = mulModLazy(cur, root_sq, root_sq_barr, prime); + cur = subIfGE(cur, prime); + if (cur < min_root) + min_root = cur; } + return min_root; +} - std::vector tmp(degree_); - tmp[0] = psi_inv_rev_[0]; - Size idx2 = 1; - for (u64 m = (degree_ >> 1); m > 0; m >>= 1) { - for (u64 i = 0; i < m; i++) { - tmp[idx2] = psi_inv_rev_[m + i]; - idx2++; - } - } - psi_inv_rev_ = std::move(tmp); +// --------------------------------------------------------------------------- +// Butterfly primitives +// --------------------------------------------------------------------------- - // Compute Shoup precomputed values using computeShoup - for (u64 i = 0; i < degree_; i++) { - psi_rev_shoup_[i] = computeShoup(psi_rev_[i], prime_); - psi_inv_rev_shoup_[i] = computeShoup(psi_inv_rev_[i], prime_); - } +template inline void butterfly(U &x, U &y, U w, U ws, U p1, U p2) { + // Precondition: prime < 2^30 for u32 (4·prime < 2^32), prime < 2^61 for + // u64 (4·prime < 2^63). Forward butterfly emits [0, 4·prime) lazy form; + // the prefix subIfGE on x absorbs that into [0, 2·prime) before the + // additions, so x + ty and x + p2 - ty stay in [0, 4·prime) — fits both + // word widths. The next butterfly's mulModLazy on y just requires op1 + // to fit U, which 4·prime does. computeForward's final canonical-reduce + // pass brings the array back to [0, prime) before computeBackward sees + // it, so butterflyInv's tighter [0, 2·prime) invariant is preserved. + const U ty = mulModLazy(y, w, ws, p1); + x = subIfGE(x, p2); + y = static_cast(x + p2 - ty); + x = static_cast(x + ty); +} - // Variables for last step of backward NTT - degree_inv_ = static_cast(invModSimple(degree_, prime)); - degree_inv_barrett_ = computeShoup(degree_inv_, prime_); - degree_inv_w_ = static_cast( - mulModSimple(static_cast(degree_inv_), - static_cast(psi_inv_rev_[degree_ - 1]), prime)); - degree_inv_w_barrett_ = computeShoup(degree_inv_w_, prime_); +template +inline void butterflyInv(U &x, U &y, U w, U ws, U p1, U p2) { + // Invariant: inputs in [0, p2) for u32 (preserved by butterfly() and + // by butterflyInv() itself), in [0, 2·p2) for u64 (the looser lazy + // form the u64 forward path emits). For u32 with prime < 2^30 this + // gives x + y, x + p2 - y both in [0, 2·p2) = [0, 4·prime) < 2^32 — + // no prefix reduction needed; the u32 and u64 bodies collapse to the + // same code. mulModLazy only requires op1 to fit a u32. + const U tx = subIfGE(static_cast(x + y), p2); + y = mulModLazy(static_cast(x + p2 - y), w, ws, p1); + x = tx; } // --------------------------------------------------------------------------- -// Forward NTT -- single butterfly pass +// Forward / Backward single-step butterfly passes +// +// The twiddle table layout is identical for negacyclic and cyclic NTT +// (only the stored values differ), so these helpers take the table base +// pointers explicitly. // --------------------------------------------------------------------------- template -void NTT::computeForwardNativeSingleStep(U *op, const u64 t) const { - const u64 degree = this->degree_; - const U prime = this->prime_; - const U two_prime = this->two_prime_; - +void forwardSingleStep(U *op, u64 degree, U prime, U two_prime, u64 t, + const U *w_base, const U *ws_base) { const u64 m = (degree >> 1) / t; - const U *w_ptr = psi_rev_.data() + m; - const U *ws_ptr = psi_rev_shoup_.data() + m; + const U *w_ptr = w_base + m; + const U *ws_ptr = ws_base + m; switch (t) { case 1: @@ -303,39 +279,13 @@ void NTT::computeForwardNativeSingleStep(U *op, const u64 t) const { } } -template void NTT::computeForward(U *op) const { - const u64 degree = this->degree_; - - for (u64 t = (degree >> 1); t > 0; t >>= 1) - computeForwardNativeSingleStep(op, t); - - const U prime = this->prime_; - const U two_prime = this->two_prime_; -#if DEB_ALINAS_LEN == 0 - PRAGMA_OMP(omp simd) -#else - PRAGMA_OMP(omp simd aligned(op : DEB_ALINAS_LEN)) -#endif - for (u64 i = 0; i < degree; i++) { - op[i] = subIfGE(op[i], two_prime); - op[i] = subIfGE(op[i], prime); - } -} - -// --------------------------------------------------------------------------- -// Inverse NTT -- single butterfly pass -// --------------------------------------------------------------------------- - template -void NTT::computeBackwardNativeSingleStep(U *op, const u64 t) const { - const u64 degree = this->degree_; - const U prime = this->prime_; - const U two_prime = this->two_prime_; - +void backwardSingleStep(U *op, u64 degree, U prime, U two_prime, u64 t, + const U *w_base, const U *ws_base) { const u64 m = (degree >> 1) / t; const u64 root_idx = 1 + degree - (degree / t); - const U *w_ptr = psi_inv_rev_.data() + root_idx; - const U *ws_ptr = psi_inv_rev_shoup_.data() + root_idx; + const U *w_ptr = w_base + root_idx; + const U *ws_ptr = ws_base + root_idx; switch (t) { case 1: @@ -432,32 +382,21 @@ void NTT::computeBackwardNativeSingleStep(U *op, const u64 t) const { } } -template void NTT::computeBackwardNativeLast(U *op) const { - const u64 degree = this->degree_; - const U prime = this->prime_; - const U two_prime = this->two_prime_; - - const U degree_inv = this->degree_inv_; - const U degree_inv_br = this->degree_inv_barrett_; - const U degree_inv_w = this->degree_inv_w_; - const U degree_inv_w_br = this->degree_inv_w_barrett_; - +template +void backwardLast(U *op, u64 degree, U prime, U two_prime, U deg_inv, + U deg_inv_barr, U deg_inv_w, U deg_inv_w_barr) { auto butterfly_inv_degree = [&](U &x, U &y) { - if constexpr (std::is_same_v) { - u64 x64 = static_cast(subIfGE(x, two_prime)); - u64 y64 = static_cast(subIfGE(y, two_prime)); - u64 p2_64 = static_cast(two_prime); - U tx = static_cast(subIfGE(x64 + y64, p2_64)); - U ty = static_cast(subIfGE(x64 + p2_64 - y64, p2_64)); - x = mulModLazy(tx, degree_inv, degree_inv_br, prime); - y = mulModLazy(ty, degree_inv_w, degree_inv_w_br, prime); - } else { - U tx = static_cast(x + y); - U ty = static_cast(x + two_prime - y); - tx = subIfGE(tx, two_prime); - x = mulModLazy(tx, degree_inv, degree_inv_br, prime); - y = mulModLazy(ty, degree_inv_w, degree_inv_w_br, prime); - } + // Inputs come from the previous backwardSingleStep / butterflyInv, + // which preserves the [0, 2·prime) invariant for u32 (and the looser + // u64 invariant for u64). Either way x + y and x + 2p - y stay + // within the word width when prime < 2^30 (u32) / prime < 2^61 (u64), + // so the u32 and u64 bodies are the same — no prefix reduction and + // no u64 promotion. + U tx = static_cast(x + y); + U ty = static_cast(x + two_prime - y); + tx = subIfGE(tx, two_prime); + x = mulModLazy(tx, deg_inv, deg_inv_barr, prime); + y = mulModLazy(ty, deg_inv_w, deg_inv_w_barr, prime); }; U *x_ptr = op; @@ -476,32 +415,453 @@ template void NTT::computeBackwardNativeLast(U *op) const { } } +// Applies the stack-style inverse-psi reshuffle that iNTT requires. +// In-place transform of psi_inv_rev_ from bit-reversed layout into the +// stage-prefix layout the backward butterfly walks. +template +void reshuffleInversePsi(std::vector &psi_inv_rev, u64 degree) { + std::vector tmp(degree); + tmp[0] = psi_inv_rev[0]; + Size idx = 1; + for (u64 m = (degree >> 1); m > 0; m >>= 1) { + for (u64 i = 0; i < m; ++i) { + tmp[idx] = psi_inv_rev[m + i]; + ++idx; + } + } + psi_inv_rev = std::move(tmp); +} + +inline u64 multBarr(u64 x, u64 y, u64 y_barr, u64 prime_mod) { + u64 res = mulModLazy(x, y, y_barr, prime_mod); + return subIfGE(res, prime_mod); +} + +} // namespace + +// --------------------------------------------------------------------------- +// NTT_base — common state initialization +// --------------------------------------------------------------------------- + +template +NTT_base::NTT_base(u64 degree, u64 prime, NTTType type, + NTTRootType root_type) + : prime_(static_cast(prime)), two_prime_(static_cast(prime * 2)), + degree_(degree), type_(type), root_type_(root_type) { + if (!isPowerOfTwo(degree_)) + throw std::runtime_error("[NTT] degree must be a power of two."); + if (prime % (2 * degree_) != 1) + throw std::runtime_error( + "[NTT] Not an NTT-friendly prime given " + "(prime must satisfy prime ≡ 1 mod 2·degree)."); +} + +// --------------------------------------------------------------------------- +// NTT — negacyclic constructor +// --------------------------------------------------------------------------- + +template +NTT::NTT(u64 degree, u64 prime, NTTRootType root_type) + : NTT_base(degree, prime, NTTType::NEGACYCLIC, root_type), + psi_rev_(degree), psi_inv_rev_(degree), psi_rev_shoup_(degree), + psi_inv_rev_shoup_(degree) { + + const u64 num_roots = degree_; + + degree_inv_ = static_cast(invModSimple(degree_, prime)); + degree_inv_barrett_ = computeShoup(degree_inv_, prime_); + + // Find primitive 2N-th root ψ. + u64 psi; + switch (root_type) { + case NTTRootType::MIN: { + psi = utils::findPrimitiveRoot(prime); + psi = powModSimple(psi, (prime - 1) / (2 * num_roots), prime); + psi = selectMinRoot(prime, num_roots, psi); + break; + } + case NTTRootType::DIRECT: + psi = findPrimitiveRootDirect(prime, num_roots); + break; + case NTTRootType::CUSTOM: + psi = detail::lookupCustomPsi(degree, prime, "NTT(NEGACYCLIC)"); + break; + default: + throw std::runtime_error("[NTT(NEGACYCLIC)] Unknown NTTRootType."); + } + + u64 psi_inv = invModSimple(psi, prime); + psi_rev_[0] = U(1); + psi_inv_rev_[0] = U(1); + + const u64 max_digits = log2floor(degree_); + const u64 psi_barr = divide128By64Lo(psi, 0, prime); + const u64 psi_inv_barr = divide128By64Lo(psi_inv, 0, prime); + u64 idx = 0; + u64 previdx = 0; + for (u64 i = 1; i < degree_; i++) { + idx = bitReverse(static_cast(i), max_digits); + psi_rev_[idx] = static_cast(multBarr( + static_cast(psi_rev_[previdx]), psi, psi_barr, prime)); + psi_inv_rev_[idx] = + static_cast(multBarr(static_cast(psi_inv_rev_[previdx]), + psi_inv, psi_inv_barr, prime)); + previdx = idx; + } + + reshuffleInversePsi(psi_inv_rev_, degree_); + + for (u64 i = 0; i < degree_; i++) { + psi_rev_shoup_[i] = computeShoup(psi_rev_[i], prime_); + psi_inv_rev_shoup_[i] = computeShoup(psi_inv_rev_[i], prime_); + } + + degree_inv_w_ = static_cast( + mulModSimple(static_cast(degree_inv_), + static_cast(psi_inv_rev_[degree_ - 1]), prime)); + degree_inv_w_barrett_ = computeShoup(degree_inv_w_, prime_); +} + +template void NTT::computeForward(U *op) const { + for (u64 t = (degree_ >> 1); t > 0; t >>= 1) + forwardSingleStep(op, degree_, prime_, two_prime_, t, + psi_rev_.data(), psi_rev_shoup_.data()); + +#if DEB_ALINAS_LEN == 0 + PRAGMA_OMP(omp simd) +#else + PRAGMA_OMP(omp simd aligned(op : DEB_ALINAS_LEN)) +#endif + for (u64 i = 0; i < degree_; i++) { + op[i] = subIfGE(op[i], two_prime_); + op[i] = subIfGE(op[i], prime_); + } +} + template void NTT::computeBackward(U *op) const { + const u64 half_degree = degree_ >> 1; + + for (u64 t = 1; t < half_degree; t <<= 1) + backwardSingleStep(op, degree_, prime_, two_prime_, t, + psi_inv_rev_.data(), psi_inv_rev_shoup_.data()); + + backwardLast(op, degree_, prime_, two_prime_, degree_inv_, + degree_inv_barrett_, degree_inv_w_, degree_inv_w_barrett_); + +#if DEB_ALINAS_LEN == 0 + PRAGMA_OMP(omp simd) +#else + PRAGMA_OMP(omp simd aligned(op : DEB_ALINAS_LEN)) +#endif + for (u64 i = 0; i < degree_; i++) + op[i] = subIfGE(op[i], prime_); +} + +// --------------------------------------------------------------------------- +// NTT_C — cyclic constructor +// --------------------------------------------------------------------------- + +template +NTT_C::NTT_C(u64 degree, u64 prime, NTTRootType root_type) + : NTT_base(degree, prime, NTTType::CYCLIC, root_type), psi_rev_(degree), + psi_inv_rev_(degree), psi_rev_shoup_(degree), psi_inv_rev_shoup_(degree) { + + const u64 num_roots_cyc = 2 * degree_; // primitive (4N)-th root + if (prime % (2 * num_roots_cyc) != 1) + throw std::runtime_error( + "[NTT(CYCLIC)] CYCLIC mode requires prime ≡ 1 mod 4·degree " + "(no primitive 4N-th root of unity exists otherwise)."); + + degree_inv_ = static_cast(invModSimple(degree_, prime)); + degree_inv_barrett_ = computeShoup(degree_inv_, prime_); + + // Find primitive 4N-th root ζ. + u64 zeta; + switch (root_type) { + case NTTRootType::MIN: { + zeta = utils::findPrimitiveRoot(prime); + zeta = powModSimple(zeta, (prime - 1) / (2 * num_roots_cyc), prime); + zeta = selectMinRoot(prime, num_roots_cyc, zeta); + break; + } + case NTTRootType::DIRECT: + zeta = findPrimitiveRootDirect(prime, num_roots_cyc); + break; + case NTTRootType::CUSTOM: + zeta = detail::lookupCustomPsi(2 * degree, prime, "NTT(CYCLIC)"); + break; + default: + throw std::runtime_error("[NTT(CYCLIC)] Unknown NTTRootType."); + } + + // Build a full [1, ζ, ζ², …, ζ^{2N-1}] table; we need every 4th entry + // for the layered psi_rev table plus the first N entries for roots_. + std::vector zeta_pow(2 * degree_); + zeta_pow[0] = 1; + if (2 * degree_ > 1) { + zeta_pow[1] = zeta; + const u64 zeta_barr = divide128By64Lo(zeta, 0, prime); + for (u64 i = 2; i < 2 * degree_; ++i) + zeta_pow[i] = multBarr(zeta_pow[i - 1], zeta, zeta_barr, prime); + } + + // CI <-> cyclic ring conversion tables. + roots_.assign(degree_, U(0)); + roots_inv_.assign(degree_ + 1, U(0)); + for (u64 i = 0; i < degree_; ++i) + roots_[i] = static_cast(zeta_pow[i]); + roots_inv_[0] = U(1); + for (u64 i = 1; i <= degree_; ++i) + roots_inv_[i] = static_cast(prime - zeta_pow[2 * degree_ - i]); + roots_shoup_.resize(degree_); + roots_inv_shoup_.resize(degree_ + 1); + for (u64 i = 0; i < degree_; ++i) + roots_shoup_[i] = computeShoup(roots_[i], prime_); + for (u64 i = 0; i <= degree_; ++i) + roots_inv_shoup_[i] = computeShoup(roots_inv_[i], prime_); + + // Build the layered psi_rev table from ω = ζ^4 (primitive N-th root). + psi_rev_[0] = U(1); + psi_inv_rev_[0] = U(1); + const u64 half_deg = degree_ >> 1; + + if (half_deg > 0) { + std::vector psi_half(half_deg); + std::vector psi_inv_half(half_deg); + psi_half[0] = 1; + psi_inv_half[0] = 1; + for (u64 i = 1; i < half_deg; ++i) { + psi_half[i] = zeta_pow[i * 4]; + psi_inv_half[i] = prime - zeta_pow[2 * degree_ - i * 4]; + } + // Bit-reverse in place over length half_deg. + const u64 hd_bits = log2floor(half_deg); + for (u64 i = 0; i < half_deg; ++i) { + u64 j = static_cast( + bitReverse(static_cast(i), static_cast(hd_bits))); + if (j > i) { + std::swap(psi_half[i], psi_half[j]); + std::swap(psi_inv_half[i], psi_inv_half[j]); + } + } + for (u64 i = 0; i < half_deg; ++i) { + psi_rev_[half_deg + i] = static_cast(psi_half[i]); + psi_inv_rev_[half_deg + i] = static_cast(psi_inv_half[i]); + } + for (u64 m = (half_deg >> 1); m != 0; m >>= 1) { + for (u64 i = 0; i < m; ++i) { + psi_rev_[m + i] = static_cast(psi_half[i]); + psi_inv_rev_[m + i] = static_cast(psi_inv_half[i]); + } + } + } + + reshuffleInversePsi(psi_inv_rev_, degree_); + + for (u64 k = 0; k < degree_; ++k) { + psi_rev_shoup_[k] = computeShoup(psi_rev_[k], prime_); + psi_inv_rev_shoup_[k] = computeShoup(psi_inv_rev_[k], prime_); + } + + degree_inv_w_ = static_cast( + mulModSimple(static_cast(degree_inv_), + static_cast(psi_inv_rev_[degree_ - 1]), prime)); + degree_inv_w_barrett_ = computeShoup(degree_inv_w_, prime_); +} + +// Z_q[X + X^{-1}]/ -> Z_q[X]/ (in place). +template void NTT_C::conversion(U *op) const { + const U prime = this->prime_; + const U two_prime = this->two_prime_; + const u64 degree = this->degree_; + + U *op_ptr = op + 1; + U *op_ptr_back = op + degree - 1; + U *res_ptr = op + 1; + U *res_ptr_back = op + degree - 1; + + const U *roots_ptr = roots_.data() + 1; + const U *roots_ptr_back = roots_.data() + degree - 1; + const U *roots_sh_ptr = roots_shoup_.data() + 1; + const U *roots_sh_ptr_back = roots_shoup_.data() + degree - 1; + const U root_const = roots_inv_[degree]; + const U root_const_sh = roots_inv_shoup_[degree]; + + while (op_ptr != op_ptr_back) { + const U op1 = *op_ptr++, op2 = *op_ptr_back--; + U tmp1 = op1 + mulModLazy(op2, root_const, root_const_sh, prime); + U tmp2 = op2 + mulModLazy(op1, root_const, root_const_sh, prime); + tmp1 = subIfGE(tmp1, two_prime); + tmp2 = subIfGE(tmp2, two_prime); + *res_ptr++ = mulModLazy(tmp1, *roots_ptr++, *roots_sh_ptr++, prime); + *res_ptr_back-- = + mulModLazy(tmp2, *roots_ptr_back--, *roots_sh_ptr_back--, prime); + } + const U op1 = *op_ptr; + U tmp = op1 + mulModLazy(op1, root_const, root_const_sh, prime); + tmp = subIfGE(tmp, two_prime); + *res_ptr = mulModLazy(tmp, *roots_ptr, *roots_sh_ptr, prime); +} + +template void NTT_C::inversion(U *op) const { + const U prime = this->prime_; + const U two_prime = this->two_prime_; const u64 degree = this->degree_; - const u64 half_degree = degree >> 1; + + U *op_ptr = op + 1; + U *op_ptr_back = op + degree - 1; + U *res_ptr = op + 1; + U *res_ptr_back = op + degree - 1; + + const U *roots_ptr = roots_.data() + 1; + const U *roots_ptr_back = roots_.data() + degree - 1; + const U *roots_sh_ptr = roots_shoup_.data() + 1; + const U *roots_sh_ptr_back = roots_shoup_.data() + degree - 1; + const U *roots_inv_ptr = roots_inv_.data() + 1; + const U *roots_inv_ptr_back = roots_inv_.data() + degree - 1; + const U *roots_inv_sh_ptr = roots_inv_shoup_.data() + 1; + const U *roots_inv_sh_ptr_back = roots_inv_shoup_.data() + degree - 1; + + while (op_ptr != op_ptr_back) { + const U op1 = *op_ptr++, op2 = *op_ptr_back--; + *res_ptr++ = + mulModLazy(op1, *roots_inv_ptr++, *roots_inv_sh_ptr++, prime) + + mulModLazy(op2, *roots_ptr++, *roots_sh_ptr++, prime); + *res_ptr_back-- = + mulModLazy(op2, *roots_inv_ptr_back--, *roots_inv_sh_ptr_back--, + prime) + + mulModLazy(op1, *roots_ptr_back--, *roots_sh_ptr_back--, prime); + } + *res_ptr = mulModLazy(*op_ptr, *roots_inv_ptr, *roots_inv_sh_ptr, prime) + + mulModLazy(*op_ptr, *roots_ptr, *roots_sh_ptr, prime); + + // Halve (mod prime) each non-zero coefficient. Reduce to canonical [0, p) + // first so the (v+p)/2 branch never leaves the canonical range. + for (u64 i = 1; i < degree; ++i) { + U v = subIfGE(op[i], two_prime); + v = subIfGE(v, prime); + v = (v & U(1)) ? static_cast(v + prime) : v; + op[i] = static_cast(v >> 1); + } +} + +template void NTT_C::computeForward(U *op) const { + conversion(op); + + for (u64 t = (degree_ >> 1); t > 0; t >>= 1) + forwardSingleStep(op, degree_, prime_, two_prime_, t, + psi_rev_.data(), psi_rev_shoup_.data()); + +#if DEB_ALINAS_LEN == 0 + PRAGMA_OMP(omp simd) +#else + PRAGMA_OMP(omp simd aligned(op : DEB_ALINAS_LEN)) +#endif + for (u64 i = 0; i < degree_; i++) { + op[i] = subIfGE(op[i], two_prime_); + op[i] = subIfGE(op[i], prime_); + } +} + +template void NTT_C::computeBackward(U *op) const { + const u64 half_degree = degree_ >> 1; for (u64 t = 1; t < half_degree; t <<= 1) - computeBackwardNativeSingleStep(op, t); + backwardSingleStep(op, degree_, prime_, two_prime_, t, + psi_inv_rev_.data(), psi_inv_rev_shoup_.data()); - computeBackwardNativeLast(op); + backwardLast(op, degree_, prime_, two_prime_, degree_inv_, + degree_inv_barrett_, degree_inv_w_, degree_inv_w_barrett_); - const U prime = this->prime_; #if DEB_ALINAS_LEN == 0 PRAGMA_OMP(omp simd) #else PRAGMA_OMP(omp simd aligned(op : DEB_ALINAS_LEN)) #endif - for (u64 i = 0; i < degree; i++) - op[i] = subIfGE(op[i], prime); + for (u64 i = 0; i < degree_; i++) + op[i] = subIfGE(op[i], prime_); + + inversion(op); +} + +// --------------------------------------------------------------------------- +// Factory +// --------------------------------------------------------------------------- + +template +std::unique_ptr> makeNTT(u64 degree, u64 prime, NTTType type, + NTTRootType root_type) { + switch (type) { + case NTTType::NEGACYCLIC: + return std::make_unique>(degree, prime, root_type); + case NTTType::CYCLIC: + return std::make_unique>(degree, prime, root_type); + default: + throw std::runtime_error( + "[makeNTT] NTTType must be NEGACYCLIC or CYCLIC."); + } } // Explicit instantiations #ifdef DEB_U64 +template class NTT_base; template class NTT; +template class NTT_C; +template std::unique_ptr> makeNTT(u64, u64, NTTType, + NTTRootType); #endif #ifdef DEB_U32 +template class NTT_base; template class NTT; +template class NTT_C; +template std::unique_ptr> makeNTT(u64, u64, NTTType, + NTTRootType); +#endif + +// --------------------------------------------------------------------------- +// NTT factory dispatch +// --------------------------------------------------------------------------- + +namespace { +template struct NTTFactoryStorage { + static inline std::mutex mu; + static inline NTTFactory factory; +}; +} // namespace + +template void setNTTFactory(NTTFactory factory) { + std::lock_guard lock(NTTFactoryStorage::mu); + NTTFactoryStorage::factory = std::move(factory); +} + +template +std::shared_ptr> createNTT(u64 degree, u64 prime, NTTType type, + NTTRootType root_type) { + NTTFactory factory_copy; + { + std::lock_guard lock(NTTFactoryStorage::mu); + factory_copy = NTTFactoryStorage::factory; + } + if (factory_copy) { + if (auto custom = factory_copy(degree, prime, type, root_type)) { + return custom; + } + // Factory returned an empty shared_ptr — fall through to default. + } + return std::shared_ptr>( + makeNTT(degree, prime, type, root_type)); +} + +#ifdef DEB_U32 +template void setNTTFactory(NTTFactory); +template std::shared_ptr> createNTT(u64, u64, NTTType, + NTTRootType); +#endif +#ifdef DEB_U64 +template void setNTTFactory(NTTFactory); +template std::shared_ptr> createNTT(u64, u64, NTTType, + NTTRootType); #endif } // namespace deb::utils diff --git a/src/NTTConfig.cpp b/src/NTTConfig.cpp new file mode 100644 index 0000000..61d54da --- /dev/null +++ b/src/NTTConfig.cpp @@ -0,0 +1,69 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/NTTConfig.hpp" + +#include "utils/Basic.hpp" + +#include +#include +#include +#include +#include +#include + +namespace deb::utils { + +namespace { +// Default global root type is MIN, can be overridden by setGlobalNTTRootType(). +std::atomic s_root_type{NTTRootType::MIN}; +std::mutex s_custom_psi_mutex; +std::map, u64> s_custom_psi_registry; +} // namespace + +void setGlobalNTTRootType(NTTRootType type) { + s_root_type.store(type, std::memory_order_relaxed); +} + +NTTRootType getGlobalNTTRootType() { + return s_root_type.load(std::memory_order_relaxed); +} + +void registerCustomPsi(u64 degree, u64 prime, u64 psi) { + if (powModSimple(psi, 2 * degree, prime) != 1) + throw std::invalid_argument("[NTT(CUSTOM)] registerCustomPsi: " + "psi^(2*degree) != 1 mod prime."); + if (powModSimple(psi, degree, prime) == 1) + throw std::invalid_argument("[NTT(CUSTOM)] registerCustomPsi: psi " + "is not a primitive 2*degree-th root " + "(psi^degree == 1)."); + std::lock_guard lock(s_custom_psi_mutex); + s_custom_psi_registry[{degree, prime}] = psi; +} + +namespace detail { +u64 lookupCustomPsi(u64 registry_key_degree, u64 prime, const char *ctx) { + std::lock_guard lock(s_custom_psi_mutex); + auto it = s_custom_psi_registry.find({registry_key_degree, prime}); + if (it == s_custom_psi_registry.end()) + throw std::runtime_error( + std::string("[NTT(CUSTOM)] ") + ctx + + ": no custom root registered. Call registerCustomPsi() first."); + return it->second; +} +} // namespace detail + +} // namespace deb::utils diff --git a/src/SecretKeyGenerator.cpp b/src/SecretKeyGenerator.cpp index 729bf32..9b83a45 100644 --- a/src/SecretKeyGenerator.cpp +++ b/src/SecretKeyGenerator.cpp @@ -23,25 +23,29 @@ SecretKeyGeneratorT::SecretKeyGeneratorT(Preset preset) : preset_(preset) {} template SecretKeyT -SecretKeyGeneratorT::genSecretKey(std::optional seeds) { - return GenSecretKey(preset_, seeds); +SecretKeyGeneratorT::genSecretKey(std::optional seeds, + utils::NTTType ntt_type) { + return GenSecretKey(preset_, seeds, ntt_type); } template void SecretKeyGeneratorT::genSecretKeyInplace( - SecretKeyT &sk, std::optional seeds) { - GenSecretKeyInplace(sk, seeds); + SecretKeyT &sk, std::optional seeds, + utils::NTTType ntt_type) { + GenSecretKeyInplace(sk, seeds, ntt_type); } template -SecretKeyT SecretKeyGeneratorT::genSecretKeyFromCoeff(const i8 *coeffs) { - return GenSecretKeyFromCoeff(preset_, coeffs); +SecretKeyT +SecretKeyGeneratorT::genSecretKeyFromCoeff(const i8 *coeffs, + utils::NTTType ntt_type) { + return GenSecretKeyFromCoeff(preset_, coeffs, ntt_type); } template -void SecretKeyGeneratorT::genSecretKeyFromCoeffInplace(SecretKeyT &sk, - const i8 *coeffs) { - GenSecretKeyFromCoeffInplace(sk, coeffs); +void SecretKeyGeneratorT::genSecretKeyFromCoeffInplace( + SecretKeyT &sk, const i8 *coeffs, utils::NTTType ntt_type) { + GenSecretKeyFromCoeffInplace(sk, coeffs, ntt_type); } template @@ -77,17 +81,19 @@ SecretKeyGeneratorT::GenCoeffInplace(const Preset preset, i8 *coeffs, template SecretKeyT SecretKeyGeneratorT::ComputeEmbedding(const Preset preset, const i8 *coeffs, - std::optional level) { + std::optional level, + utils::NTTType ntt_type) { level = level.value_or(get_num_p(preset) - 1); SecretKeyT sk(preset); sk.allocPolys(level.value() + 1); - ComputeEmbeddingInplace(sk, coeffs); + ComputeEmbeddingInplace(sk, coeffs, ntt_type); return sk; } template void SecretKeyGeneratorT::ComputeEmbeddingInplace(SecretKeyT &sk, - const i8 *coeffs) { + const i8 *coeffs, + utils::NTTType ntt_type) { const auto dim = get_degree(sk.preset()); const auto num_secret = get_num_secret(sk.preset()); const auto rank = get_rank(sk.preset()); @@ -115,9 +121,10 @@ void SecretKeyGeneratorT::ComputeEmbeddingInplace(SecretKeyT &sk, : static_cast(prime_j - static_cast(-c)); } // TODO: reuse NTT object - utils::NTT ntt(dim, prime_j); - ntt.computeForward(sk[i][j].data()); - sk[i][j].setNTT(true); + auto ntt = utils::createNTT(dim, prime_j, ntt_type, + utils::getGlobalNTTRootType()); + ntt->computeForward(sk[i][j].data()); + sk[i][j].setNTT(ntt_type, utils::getGlobalNTTRootType()); } } } @@ -125,36 +132,39 @@ void SecretKeyGeneratorT::ComputeEmbeddingInplace(SecretKeyT &sk, template SecretKeyT SecretKeyGeneratorT::GenSecretKey(Preset preset, - std::optional seeds) { + std::optional seeds, + utils::NTTType ntt_type) { SecretKeyT sk(preset); sk.setSeed(GenCoeffInplace(preset, sk.coeffs(), seeds)); - GenSecretKeyFromCoeffInplace(sk, sk.coeffs()); + GenSecretKeyFromCoeffInplace(sk, sk.coeffs(), ntt_type); return sk; } template void SecretKeyGeneratorT::GenSecretKeyInplace( - SecretKeyT &sk, std::optional seeds) { + SecretKeyT &sk, std::optional seeds, + utils::NTTType ntt_type) { sk.setSeed(GenCoeffInplace(sk.preset(), sk.coeffs(), seeds)); - GenSecretKeyFromCoeffInplace(sk, sk.coeffs()); + GenSecretKeyFromCoeffInplace(sk, sk.coeffs(), ntt_type); } template -SecretKeyT SecretKeyGeneratorT::GenSecretKeyFromCoeff(const Preset preset, - const i8 *coeffs) { +SecretKeyT SecretKeyGeneratorT::GenSecretKeyFromCoeff( + const Preset preset, const i8 *coeffs, utils::NTTType ntt_type) { SecretKeyT sk(preset); - GenSecretKeyFromCoeffInplace(sk, coeffs); + GenSecretKeyFromCoeffInplace(sk, coeffs, ntt_type); return sk; } template -void SecretKeyGeneratorT::GenSecretKeyFromCoeffInplace(SecretKeyT &sk, - const i8 *coeffs) { - ComputeEmbeddingInplace(sk, coeffs); +void SecretKeyGeneratorT::GenSecretKeyFromCoeffInplace( + SecretKeyT &sk, const i8 *coeffs, utils::NTTType ntt_type) { + ComputeEmbeddingInplace(sk, coeffs, ntt_type); } template -void completeSecretKey(SecretKeyT &sk, std::optional level) { +void completeSecretKey(SecretKeyT &sk, std::optional level, + utils::NTTType ntt_type) { const auto rank = get_rank(sk.preset()); const auto num_secret = get_num_secret(sk.preset()); const auto degree = get_degree(sk.preset()); @@ -172,18 +182,20 @@ void completeSecretKey(SecretKeyT &sk, std::optional level) { sk[0].size() != level.value() + 1) { sk.allocPolys(level.value() + 1); } - SecretKeyGeneratorT::ComputeEmbeddingInplace(sk, sk.coeffs()); + SecretKeyGeneratorT::ComputeEmbeddingInplace(sk, sk.coeffs(), ntt_type); } // Explicit instantiations #ifdef DEB_U64 template class SecretKeyGeneratorT; -template void completeSecretKey(SecretKeyT &, std::optional); +template void completeSecretKey(SecretKeyT &, std::optional, + utils::NTTType); #endif #ifdef DEB_U32 template class SecretKeyGeneratorT; -template void completeSecretKey(SecretKeyT &, std::optional); +template void completeSecretKey(SecretKeyT &, std::optional, + utils::NTTType); #endif } // namespace deb diff --git a/src/SeedGenerator.cpp b/src/SeedGenerator.cpp index b426215..5f13f07 100644 --- a/src/SeedGenerator.cpp +++ b/src/SeedGenerator.cpp @@ -44,7 +44,6 @@ SeedGenerator::SeedGenerator(const std::optional &seeds) { ptr[j] = rd(); } } - // seeds.emplace(nseeds); rng_ = createRandomGenerator(nseeds); } else { rng_ = createRandomGenerator(seeds.value()); diff --git a/src/Serialize.cpp b/src/Serialize.cpp index cc058eb..6b66d96 100644 --- a/src/Serialize.cpp +++ b/src/Serialize.cpp @@ -113,14 +113,20 @@ FCoeffMessage deserializeFCoeff(const deb_fb::Coeff32 *coeff) { flatbuffers::Offset serializePolyUnit(flatbuffers::FlatBufferBuilder &builder, const PolyUnit &polyunit) { + // encoding ntt type and root type into a single int + int ntt_info = static_cast(polyunit.getNTTType()) * 10 + + static_cast(polyunit.getNTTRootType()); return deb_fb::CreatePolyUnit( - builder, polyunit.prime(), polyunit.degree(), polyunit.isNTT(), + builder, polyunit.prime(), polyunit.degree(), ntt_info, builder.CreateVector(polyunit.data(), polyunit.degree())); } PolyUnit deserializePolyUnit(const deb_fb::PolyUnit *polyunit) { PolyUnit poly_t(polyunit->prime(), polyunit->degree()); - poly_t.setNTT(polyunit->ntt_state()); + int ntt_info = polyunit->ntt_info(); + // encoding ntt type and root type into a single int + poly_t.setNTT(static_cast(ntt_info / 10), + static_cast(ntt_info % 10)); std::memcpy(poly_t.data(), polyunit->array()->data(), poly_t.degree() * sizeof(u64)); return poly_t; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 221c41a..d9d009e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -51,7 +51,7 @@ if(DEB_SUPPORT_U64) target_link_libraries(KeyGen-test PRIVATE ${deb_test}) add_gtest_target_to_ctest(KeyGen-test) - # Ntt Test + # NTT Test add_executable(NTT-test NTT-test.cpp) target_link_libraries(NTT-test PRIVATE ${deb_test}) add_gtest_target_to_ctest(NTT-test) diff --git a/test/EnDecryption-test.cpp b/test/EnDecryption-test.cpp index 8b2bd07..3bc5672 100644 --- a/test/EnDecryption-test.cpp +++ b/test/EnDecryption-test.cpp @@ -16,6 +16,7 @@ #include "DebParam.hpp" #include "TestBase.hpp" +#include "utils/FFT.hpp" #include "utils/OmpUtils.hpp" using namespace deb; @@ -268,8 +269,197 @@ TEST_P(EnDecrypt, ScaleEncryptAndDecryptCoeffWithEncKey) { } } +/*--------------------------------------------------- + Real Encryption Tests +---------------------------------------------------*/ +class RealEnDecrypt : public DebTestBase { +public: + RealEnDecrypt() { + for (auto &p : Presets) { + if (get_preset_name(p) == "FGbD12L0") { + preset = p; + break; + } + } + if (get_preset_name(preset) == "FGbD12L0") { + encryptor = Encryptor(preset); + decryptor = Decryptor(preset); + num_slots = get_num_slots(preset); + degree = get_degree(preset); + num_secret = get_num_secret(preset); + } + } + void SetUp() override { + if (get_preset_name(preset) != "FGbD12L0") { + GTEST_SKIP() + << "No suitable preset found for real encryption tests."; + } + } +}; + +TEST_P(RealEnDecrypt, EncryptAndDecryptWithSecretKey) { + MSGS msg = gen_empty_real_message(); + for (Size s = 0; s < num_secret; ++s) { + for (Size j = 0; j < degree; ++j) { + msg[s][j].real(dist(gen)); + msg[s][j].imag(0.0); + } + } + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset, std::nullopt, + utils::NTTType::CYCLIC); + MSGS decrypted_msg = gen_empty_real_message(); + + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); + MSGS scaled_msg = scale_complex_message(msg, l); + encryptor.encrypt(scaled_msg, sk, ctxt, + EncryptOptions().Level(l).RealEncrypt(true)); + decryptor.decrypt(ctxt, sk, decrypted_msg); + compare_heaan_msg(scaled_msg, decrypted_msg, scale_error(sk_err, l)); + } +} + +TEST_P(RealEnDecrypt, EncryptAndDecryptFloatWithSecretKey) { + FMSGS msg = gen_empty_real_message(); + for (Size s = 0; s < num_secret; ++s) { + for (Size j = 0; j < degree; ++j) { + msg[s][j].real(static_cast(dist(gen))); + msg[s][j].imag(0.0f); + } + } + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset, std::nullopt, + utils::NTTType::CYCLIC); + FMSGS decrypted_msg = gen_empty_real_message(); + + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); + FMSGS scaled_msg = scale_complex_message(msg, l); + encryptor.encrypt(scaled_msg, sk, ctxt, + EncryptOptions().Level(l).RealEncrypt(true)); + decryptor.decrypt(ctxt, sk, decrypted_msg); + compare_heaan_msg(scaled_msg, decrypted_msg, scale_error(sk_err_f, l)); + } +} + +TEST_P(RealEnDecrypt, EncryptAndDecryptWithEncKey) { + MSGS msg = gen_empty_real_message(); + for (Size s = 0; s < num_secret; ++s) { + for (Size j = 0; j < degree; ++j) { + msg[s][j].real(dist(gen)); + msg[s][j].imag(0.0); + } + } + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset, std::nullopt, + utils::NTTType::CYCLIC); + KeyGenerator keygen(preset); + SwitchKey enckey = keygen.genEncKey(sk); + MSGS decrypted_msg = gen_empty_real_message(); + + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); + MSGS scaled_msg = scale_complex_message(msg, l); + encryptor.encrypt(scaled_msg, enckey, ctxt, + EncryptOptions().Level(l).RealEncrypt(true)); + decryptor.decrypt(ctxt, sk, decrypted_msg); + compare_heaan_msg(scaled_msg, decrypted_msg, scale_error(enc_err, l)); + } +} + +TEST_P(RealEnDecrypt, EncryptAndDecryptFloatWithEncKey) { + FMSGS msg = gen_empty_real_message(); + for (Size s = 0; s < num_secret; ++s) { + for (Size j = 0; j < degree; ++j) { + msg[s][j].real(static_cast(dist(gen))); + msg[s][j].imag(0.0f); + } + } + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset, std::nullopt, + utils::NTTType::CYCLIC); + KeyGenerator keygen(preset); + SwitchKey enckey = keygen.genEncKey(sk); + + FMSGS decrypted_msg = gen_empty_real_message(); + + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); + FMSGS scaled_msg = scale_complex_message(msg, l); + encryptor.encrypt(scaled_msg, enckey, ctxt, + EncryptOptions().Level(l).RealEncrypt(true)); + decryptor.decrypt(ctxt, sk, decrypted_msg); + compare_heaan_msg(scaled_msg, decrypted_msg, scale_error(enc_err_f, l)); + } +} + +TEST_P(RealEnDecrypt, EncryptAndDecryptCoeffWithSecretKey) { + COEFFS coeff = gen_random_coeff(); + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset, std::nullopt, + utils::NTTType::CYCLIC); + COEFFS decrypted_coeff = gen_empty_coeff(); + + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); + COEFFS scaled_coeff = scale_coeff(coeff, l); + encryptor.encrypt(scaled_coeff, sk, ctxt, + EncryptOptions().Level(l).RealEncrypt(true)); + decryptor.decrypt(ctxt, sk, decrypted_coeff); + compare_coeff(scaled_coeff, decrypted_coeff, scale_error(sk_err_f, l)); + } +} + +TEST_P(RealEnDecrypt, EncryptAndDecryptFloatCoeffWithSecretKey) { + FCOEFFS coeff = gen_random_coeff(); + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset, std::nullopt, + utils::NTTType::CYCLIC); + FCOEFFS decrypted_coeff = gen_empty_coeff(); + + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); + encryptor.encrypt(coeff, sk, ctxt, + EncryptOptions().Level(l).RealEncrypt(true)); + decryptor.decrypt(ctxt, sk, decrypted_coeff); + compare_coeff(coeff, decrypted_coeff, scale_error(sk_err, l)); + } +} + +TEST_P(RealEnDecrypt, EncryptAndDecryptCoeffWithEncKey) { + COEFFS coeff = gen_random_coeff(); + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset, std::nullopt, + utils::NTTType::CYCLIC); + KeyGenerator keygen(preset); + SwitchKey enckey = keygen.genEncKey(sk); + COEFFS decrypted_coeff = gen_empty_coeff(); + + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); + encryptor.encrypt(coeff, enckey, ctxt, + EncryptOptions().Level(l).RealEncrypt(true)); + decryptor.decrypt(ctxt, sk, decrypted_coeff); + compare_coeff(coeff, decrypted_coeff, scale_error(enc_err_f, l)); + } +} + +TEST_P(RealEnDecrypt, EncryptAndDecryptFloatCoeffWithEncKey) { + FCOEFFS coeff = gen_random_coeff(); + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset, std::nullopt, + utils::NTTType::CYCLIC); + KeyGenerator keygen(preset); + SwitchKey enckey = keygen.genEncKey(sk); + FCOEFFS decrypted_coeff = gen_empty_coeff(); + + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); + encryptor.encrypt(coeff, enckey, ctxt, + EncryptOptions().Level(l).RealEncrypt(true)); + decryptor.decrypt(ctxt, sk, decrypted_coeff); + compare_coeff(coeff, decrypted_coeff, scale_error(enc_err_f, l)); + } +} + #define X(PRESET) Preset::PRESET_##PRESET, const std::vector all_presets = {PRESET_LIST #undef X }; INSTANTIATE_TEST_SUITE_P(EnDecrypt, EnDecrypt, testing::ValuesIn(all_presets)); + +INSTANTIATE_TEST_SUITE_P(RealEnDecrypt, RealEnDecrypt, + testing::Values(static_cast(0))); diff --git a/test/NTT-test.cpp b/test/NTT-test.cpp index 9438576..0f6f8c3 100644 --- a/test/NTT-test.cpp +++ b/test/NTT-test.cpp @@ -15,6 +15,7 @@ */ #include "utils/Basic.hpp" +#include "utils/FFT.hpp" #include "utils/NTT.hpp" #include @@ -25,7 +26,7 @@ using namespace deb; using namespace std; -class NttTest : public ::testing::TestWithParam> { +class NTTTest : public ::testing::TestWithParam> { public: const u64 degree{get<0>(GetParam())}; const u64 prime{get<1>(GetParam())}; @@ -71,7 +72,7 @@ inline u64 findMinPrimitiveRoot(u64 degree, u64 prime) { return psi; } -TEST_P(NttTest, SameAfterNTTandiNTT) { +TEST_P(NTTTest, SameAfterNTTandiNTT) { utils::NTT ntt{degree, prime}; auto *op = getRandomVector(); @@ -83,7 +84,7 @@ TEST_P(NttTest, SameAfterNTTandiNTT) { EXPECT_EQ(res, op); } -TEST_P(NttTest, PerformNTTforOneZeroVector) { +TEST_P(NTTTest, PerformNTTforOneZeroVector) { utils::NTT ntt{degree, prime}; std::vector op1(degree); @@ -95,11 +96,756 @@ TEST_P(NttTest, PerformNTTforOneZeroVector) { EXPECT_EQ(op1, op2); } -INSTANTIATE_TEST_SUITE_P(61bitPrimes, NttTest, +INSTANTIATE_TEST_SUITE_P(61bitPrimes, NTTTest, testing::Values(std::tuple{1 << 15, 2305843009146585089})); -INSTANTIATE_TEST_SUITE_P(40bitPrimes, NttTest, +INSTANTIATE_TEST_SUITE_P(40bitPrimes, NTTTest, testing::Values(std::tuple{1 << 13, 2199020634113})); -INSTANTIATE_TEST_SUITE_P(TinyDegree, NttTest, +INSTANTIATE_TEST_SUITE_P(TinyDegree, NTTTest, testing::Values(std::tuple{64, 4295688193})); + +// --------------------------------------------------------------------------- +// NTTRootType::DIRECT +// --------------------------------------------------------------------------- + +TEST_P(NTTTest, DirectSameAfterNTTandiNTT) { + utils::ScopedNTTRootType guard{utils::NTTRootType::DIRECT}; + utils::NTT ntt{degree, prime}; + + auto *op = getRandomVector(); + auto *res = op; + + ntt.computeForward(res); + ntt.computeBackward(res); + + EXPECT_EQ(res, op); +} + +TEST_P(NTTTest, DirectPerformNTTforOneZeroVector) { + utils::ScopedNTTRootType guard{utils::NTTRootType::DIRECT}; + utils::NTT ntt{degree, prime}; + + std::vector op1(degree); + op1[0] = 1; + std::vector op2(degree, 1); + + ntt.computeForward(op1.data()); + + EXPECT_EQ(op1, op2); +} + +INSTANTIATE_TEST_SUITE_P(DirectMode61bit, NTTTest, + testing::Values(std::tuple{1 << 15, + 2305843009146585089})); +INSTANTIATE_TEST_SUITE_P(DirectMode40bit, NTTTest, + testing::Values(std::tuple{1 << 13, 2199020634113})); +INSTANTIATE_TEST_SUITE_P(DirectModeTiny, NTTTest, + testing::Values(std::tuple{64, 4295688193})); + +// --------------------------------------------------------------------------- +// NTTRootType::CUSTOM +// --------------------------------------------------------------------------- + +// Compute a primitive 2*degree-th root of unity by trying small bases. +// Does not depend on findPrimitiveRoot() so avoids the unresolved-symbol issue +// with the nested utils::utils namespace in the library. +inline u64 findTestPsi(u64 degree, u64 prime) { + for (u64 base : {u64(3), u64(5), u64(6), u64(7), u64(11), u64(13)}) { + u64 psi = utils::powModSimple(base, (prime - 1) / (2 * degree), prime); + if (psi != 1 && utils::powModSimple(psi, degree, prime) != 1) + return psi; + } + throw std::runtime_error( + "findTestPsi: no primitive root found for test prime"); +} + +// Compute a primitive 4*degree-th root of unity by trying small bases. +// CYCLIC mode uses a primitive 4N-th root internally (ω = ζ^4 is the +// primitive N-th root that drives the cyclic butterfly), so for CUSTOM +// root_type a 4N-th root must be registered under the key (2*degree, prime). +// Requires prime ≡ 1 mod 4*degree. +inline u64 findTestZeta(u64 degree, u64 prime) { + const u64 four_deg = 4 * degree; + for (u64 base : {u64(3), u64(5), u64(6), u64(7), u64(11), u64(13), u64(17), + u64(19), u64(23)}) { + u64 zeta = utils::powModSimple(base, (prime - 1) / four_deg, prime); + if (zeta != 1 && utils::powModSimple(zeta, 2 * degree, prime) != 1) + return zeta; + } + throw std::runtime_error( + "findTestZeta: no primitive 4N-th root found for test prime"); +} + +// Register the 4N-th root used by the cyclic NTT table. Stored under +// the (2*degree, prime) key so it does not collide with the negacyclic +// 2N-th root registered under (degree, prime). +inline void registerCustomZeta(u64 degree, u64 prime) { + utils::registerCustomPsi(2 * degree, prime, findTestZeta(degree, prime)); +} + +TEST_P(NTTTest, CustomSameAfterNTTandiNTT) { + utils::registerCustomPsi(degree, prime, findTestPsi(degree, prime)); + utils::ScopedNTTRootType guard{utils::NTTRootType::CUSTOM}; + utils::NTT ntt{degree, prime}; + + auto *op = getRandomVector(); + auto *res = op; + + ntt.computeForward(res); + ntt.computeBackward(res); + + EXPECT_EQ(res, op); +} + +TEST_P(NTTTest, CustomPerformNTTforOneZeroVector) { + utils::registerCustomPsi(degree, prime, findTestPsi(degree, prime)); + utils::ScopedNTTRootType guard{utils::NTTRootType::CUSTOM}; + utils::NTT ntt{degree, prime}; + + std::vector op1(degree); + op1[0] = 1; + std::vector op2(degree, 1); + + ntt.computeForward(op1.data()); + + EXPECT_EQ(op1, op2); +} + +INSTANTIATE_TEST_SUITE_P(CustomMode61bit, NTTTest, + testing::Values(std::tuple{1 << 15, + 2305843009146585089})); +INSTANTIATE_TEST_SUITE_P(CustomMode40bit, NTTTest, + testing::Values(std::tuple{1 << 13, 2199020634113})); +INSTANTIATE_TEST_SUITE_P(CustomModeTiny, NTTTest, + testing::Values(std::tuple{64, 4295688193})); + +// ============================================================================ +// Cyclic NTT tests +// ============================================================================ + +class CyclicNTTTest : public ::testing::TestWithParam> { +public: + const u64 degree{get<0>(GetParam())}; + const u64 prime{get<1>(GetParam())}; + std::mt19937_64 gen{std::random_device{}()}; + + vector random_vec() { + vector v(degree); + uniform_int_distribution dist(0, prime - 1); + for (auto &x : v) + x = dist(gen); + return v; + } +}; + +// Cyclic NTT also requires prime ≡ 1 mod 2·degree (same primes as negacyclic). +INSTANTIATE_TEST_SUITE_P(Cyclic61bit, CyclicNTTTest, + testing::Values(std::tuple{1 << 15, + 2305843009146585089ULL})); +INSTANTIATE_TEST_SUITE_P(Cyclic40bit, CyclicNTTTest, + testing::Values(std::tuple{1 << 13, + 2199020634113ULL})); +INSTANTIATE_TEST_SUITE_P(CyclicTiny, CyclicNTTTest, + testing::Values(std::tuple{64, 4295688193ULL})); + +TEST_P(CyclicNTTTest, RoundTrip) { + utils::NTT_C ntt{degree, prime}; + + auto v = random_vec(); + auto result = v; + ntt.computeForward(result.data()); + ntt.computeBackward(result.data()); + EXPECT_EQ(result, v); +} + +TEST_P(CyclicNTTTest, ForwardOfConstantOne) { + // The CI-subring vector (1, 0, …, 0) survives conversion() unchanged + // (op[0] is left alone and the rest are zero), so the layered butterfly + // ends up evaluating f(x) = 1 at every N-th root of unity — every bin is 1. + utils::NTT_C ntt{degree, prime}; + + vector op(degree, 0); + op[0] = 1; + ntt.computeForward(op.data()); + EXPECT_EQ(op, vector(degree, 1)); +} + +TEST_P(CyclicNTTTest, DirectRoundTrip) { + utils::NTT_C ntt{degree, prime, utils::NTTRootType::DIRECT}; + + auto v = random_vec(); + auto r = v; + ntt.computeForward(r.data()); + ntt.computeBackward(r.data()); + EXPECT_EQ(r, v); +} + +TEST_P(CyclicNTTTest, CustomRoundTrip) { + // CYCLIC + CUSTOM looks up a primitive 4N-th root under the key + // (2*degree, prime). + registerCustomZeta(degree, prime); + utils::NTT_C ntt{degree, prime, utils::NTTRootType::CUSTOM}; + + auto v = random_vec(); + auto r = v; + ntt.computeForward(r.data()); + ntt.computeBackward(r.data()); + EXPECT_EQ(r, v); +} + +// ============================================================================ +// NTTRootType feature tests (non-parametrized) +// ============================================================================ + +namespace { + +// Primes and degrees reused across the feature tests below. +// kSmallPrime ≡ 1 mod 128 (NTT-friendly for degree 64) +// kSmallPrime ≡ 1 mod 256 (NTT-friendly for degree 128, used for missing- +// registration test) +constexpr u64 kSmallDegree = 64; +constexpr u64 kSmallPrime = 4295688193ULL; + +// Naive negacyclic convolution: c = a * b mod (X^N + 1) mod p. +vector negacyclicConv(const vector &a, const vector &b, + u64 prime) { + const u64 N = a.size(); + vector c(N, 0); + for (u64 i = 0; i < N; i++) { + for (u64 j = 0; j < N; j++) { + u64 prod = utils::mulModSimple(a[i], b[j], prime); + u64 k = i + j; + if (k < N) + c[k] = (c[k] + prod) % prime; + else + c[k - N] = (c[k - N] + prime - prod) % prime; + } + } + return c; +} + +} // namespace + +// ---------------------------------------------------------------------------- +// Cyclic NTT cross-mode consistency +// +// The hem-compatible cyclic NTT does not compute X^N−1 convolution on raw +// coefficients (it transforms the CI subring of Z_q[X]/ via the +// conversion()/inversion() pair), so there is no simple closed-form +// reference to compare against. What we *can* assert is that the three +// root-finding paths land on equivalent transforms: a round-trip with +// pointwise multiplication in the NTT domain must produce the same +// polynomial regardless of which primitive 4N-th root was selected. +// ---------------------------------------------------------------------------- + +TEST(CyclicNTTPolyMul, AllModesAgreeOnPointwiseProduct) { + constexpr u64 N = kSmallDegree; + constexpr u64 p = kSmallPrime; + + std::mt19937_64 rng(0xcafe1234); + vector a(N), b(N); + for (auto &x : a) + x = rng() % p; + for (auto &x : b) + x = rng() % p; + + auto nttMul = [&](utils::NTTRootType rt) { + utils::NTT_C ntt{N, p, rt}; + vector fa(a), fb(b), fc(N); + ntt.computeForward(fa.data()); + ntt.computeForward(fb.data()); + for (u64 i = 0; i < N; i++) + fc[i] = utils::mulModSimple(fa[i], fb[i], p); + ntt.computeBackward(fc.data()); + return fc; + }; + + registerCustomZeta(N, p); + + const auto min_result = nttMul(utils::NTTRootType::MIN); + const auto direct_result = nttMul(utils::NTTRootType::DIRECT); + const auto custom_result = nttMul(utils::NTTRootType::CUSTOM); + + EXPECT_EQ(min_result, direct_result) + << "CYCLIC MIN and DIRECT modes disagree on pointwise product"; + EXPECT_EQ(min_result, custom_result) + << "CYCLIC MIN and CUSTOM modes disagree on pointwise product"; +} + +// Confirms negacyclic NTT continues to compute X^N+1 convolution, and that +// the cyclic path produces a *different* polynomial (sanity check that the +// cyclic flag actually engages the alternate twiddle table and pre/post +// conversion, not the same code path as negacyclic). +TEST(CyclicNTTPolyMul, CyclicDiffersFromNegacyclic) { + constexpr u64 N = kSmallDegree; + constexpr u64 p = kSmallPrime; + + std::mt19937_64 rng(0xdead5678); + vector a(N), b(N); + for (auto &x : a) + x = rng() % p; + for (auto &x : b) + x = rng() % p; + + auto cycMul = [&]() { + utils::NTT_C ntt{N, p}; + vector fa(a), fb(b), fc(N); + ntt.computeForward(fa.data()); + ntt.computeForward(fb.data()); + for (u64 i = 0; i < N; i++) + fc[i] = utils::mulModSimple(fa[i], fb[i], p); + ntt.computeBackward(fc.data()); + return fc; + }; + auto negMul = [&]() { + utils::NTT ntt{N, p}; + vector fa(a), fb(b), fc(N); + ntt.computeForward(fa.data()); + ntt.computeForward(fb.data()); + for (u64 i = 0; i < N; i++) + fc[i] = utils::mulModSimple(fa[i], fb[i], p); + ntt.computeBackward(fc.data()); + return fc; + }; + + const auto cyc_result = cycMul(); + const auto neg_result = negMul(); + const auto neg_ref = negacyclicConv(a, b, p); + + EXPECT_EQ(neg_result, neg_ref) + << "Negacyclic NTT does not compute X^N+1 convolution"; + EXPECT_NE(cyc_result, neg_result) + << "Cyclic and negacyclic NTT paths should produce different results"; +} + +// ---------------------------------------------------------------------------- +// registerCustomPsi: validation / rejection tests +// ---------------------------------------------------------------------------- + +TEST(NTTRootTypeValidation, RegisterThrowsForTrivialRoot) { + // psi = 1 has order 1; psi^degree = 1 → rejected as not primitive. + EXPECT_THROW(utils::registerCustomPsi(kSmallDegree, kSmallPrime, 1ULL), + std::invalid_argument); +} + +TEST(NTTRootTypeValidation, RegisterThrowsForMinusOne) { + // psi = -1 mod p has order 2. degree (64) is even so psi^degree = 1 + // → rejected (order-2 root is not a primitive 2*degree-th root of unity). + EXPECT_THROW( + utils::registerCustomPsi(kSmallDegree, kSmallPrime, kSmallPrime - 1), + std::invalid_argument); +} + +TEST(NTTRootTypeValidation, RegisterThrowsForNthRootNotPrimitive) { + // psi_2N is a primitive 2N-th root of unity. + // psi_2N^2 is a primitive N-th root (order N, not 2N): + // (psi_2N^2)^N = psi_2N^(2N) = 1 → rejected. + u64 psi_2N = findTestPsi(kSmallDegree, kSmallPrime); + u64 psi_N = utils::mulModSimple(psi_2N, psi_2N, kSmallPrime); + EXPECT_THROW(utils::registerCustomPsi(kSmallDegree, kSmallPrime, psi_N), + std::invalid_argument); +} + +TEST(NTTRootTypeValidation, RegisterThrowsForArbitraryNonRoot) { + // psi = 2 is almost certainly not a 2*kSmallDegree-th root of unity + // for this prime (2^128 mod kSmallPrime ≠ 1). Verify and expect a throw. + if (utils::powModSimple(u64(2), 2 * kSmallDegree, kSmallPrime) != 1) { + EXPECT_THROW( + utils::registerCustomPsi(kSmallDegree, kSmallPrime, u64(2)), + std::invalid_argument); + } else { + GTEST_SKIP() << "2 happens to be a root for this prime; skipping"; + } +} + +// ---------------------------------------------------------------------------- +// CUSTOM NTT: missing registration +// ---------------------------------------------------------------------------- + +TEST(NTTRootTypeValidation, CustomThrowsWhenNotRegistered) { + // kSmallPrime ≡ 1 mod 512, so degree=256 is NTT-friendly. (degree=128 + // can't be used here because CyclicNTTPolyMul registers a 4N-th zeta + // under the key (2*64, kSmallPrime) = (128, kSmallPrime); that test + // may run earlier and would pollute the registry for a degree=128 + // negacyclic lookup.) + utils::ScopedNTTRootType guard{utils::NTTRootType::CUSTOM}; + EXPECT_THROW((utils::NTT{256, kSmallPrime}), std::runtime_error); +} + +// ---------------------------------------------------------------------------- +// Polynomial multiplication equivalence +// +// The central correctness property: every valid primitive 2N-th root of unity +// yields the same negacyclic convolution ring Z[X]/(X^N+1, p). +// All three NTT modes must agree with the O(N^2) reference result. +// ---------------------------------------------------------------------------- + +TEST(NTTPolyMul, AllModesAgreeOnNegacyclicConvolution) { + constexpr u64 N = kSmallDegree; + constexpr u64 p = kSmallPrime; + + std::mt19937_64 rng(0xdeb1cafe); + vector a(N), b(N); + for (auto &x : a) + x = rng() % p; + for (auto &x : b) + x = rng() % p; + + const auto ref = negacyclicConv(a, b, p); + + // Compute NTT-based negacyclic convolution for a given root type. + auto nttConv = [&](utils::NTTRootType rt) { + utils::ScopedNTTRootType guard{rt}; + utils::NTT ntt{N, p}; + vector fa(a), fb(b), fc(N); + ntt.computeForward(fa.data()); + ntt.computeForward(fb.data()); + for (u64 i = 0; i < N; i++) + fc[i] = utils::mulModSimple(fa[i], fb[i], p); + ntt.computeBackward(fc.data()); + return fc; + }; + + // Register a custom psi so the CUSTOM path can be exercised. + utils::registerCustomPsi(N, p, findTestPsi(N, p)); + + EXPECT_EQ(nttConv(utils::NTTRootType::MIN), ref) << "MIN mode mismatch"; + EXPECT_EQ(nttConv(utils::NTTRootType::DIRECT), ref) + << "DIRECT mode mismatch"; + EXPECT_EQ(nttConv(utils::NTTRootType::CUSTOM), ref) + << "CUSTOM mode mismatch"; +} + +// Additional prime to confirm the equivalence is not prime-specific. +TEST(NTTPolyMul, AllModesAgreeOn40bitPrime) { + constexpr u64 N = 16; // small degree for fast naive reference + constexpr u64 p = + 2199020634113ULL; // 40-bit NTT prime (degree 8192-friendly) + + std::mt19937_64 rng(0xcafe0123); + vector a(N), b(N); + for (auto &x : a) + x = rng() % p; + for (auto &x : b) + x = rng() % p; + + const auto ref = negacyclicConv(a, b, p); + + auto nttConv = [&](utils::NTTRootType rt) { + utils::ScopedNTTRootType guard{rt}; + utils::NTT ntt{N, p}; + vector fa(a), fb(b), fc(N); + ntt.computeForward(fa.data()); + ntt.computeForward(fb.data()); + for (u64 i = 0; i < N; i++) + fc[i] = utils::mulModSimple(fa[i], fb[i], p); + ntt.computeBackward(fc.data()); + return fc; + }; + + utils::registerCustomPsi(N, p, findTestPsi(N, p)); + + EXPECT_EQ(nttConv(utils::NTTRootType::MIN), ref); + EXPECT_EQ(nttConv(utils::NTTRootType::DIRECT), ref); + EXPECT_EQ(nttConv(utils::NTTRootType::CUSTOM), ref); +} + +// ---------------------------------------------------------------------------- +// registerCustomPsi: overwrite (re-registration) +// ---------------------------------------------------------------------------- + +TEST(NTTReregistration, RoundtripAfterPsiOverwrite) { + constexpr u64 N = kSmallDegree; + constexpr u64 p = kSmallPrime; + + // Collect two distinct valid psi values using different bases. + vector valid_psi; + for (u64 base : {u64(3), u64(5), u64(7), u64(11), u64(13), u64(17)}) { + u64 candidate = utils::powModSimple(base, (p - 1) / (2 * N), p); + if (candidate != 1 && utils::powModSimple(candidate, N, p) != 1) { + // Avoid duplicates + bool dup = false; + for (u64 v : valid_psi) + if (v == candidate) { + dup = true; + break; + } + if (!dup) + valid_psi.push_back(candidate); + } + if (valid_psi.size() == 2) + break; + } + + if (valid_psi.size() < 2) + GTEST_SKIP() << "Could not find two distinct valid psi values"; + + // Register first psi, then overwrite with the second. + utils::registerCustomPsi(N, p, valid_psi[0]); + utils::registerCustomPsi(N, p, valid_psi[1]); + + // Round-trip must work with the overwritten psi. + utils::ScopedNTTRootType guard{utils::NTTRootType::CUSTOM}; + utils::NTT ntt{N, p}; + + std::mt19937_64 rng(0xbeef); + vector v(N); + for (auto &x : v) + x = rng() % p; + vector result(v); + ntt.computeForward(result.data()); + ntt.computeBackward(result.data()); + + EXPECT_EQ(result, v); +} + +// ============================================================================ +// FFT degree-sensitivity tests +// +// FFT(N) and FFT(2N) share the first half of their roots_ table by +// construction: the gap doubles in FFT(2N) and the modulus doubles too, so +// the angle exp(i*pi * (5^i * gap) / double_degree) is identical. +// Consequence: +// * Feeding the *same* size-(N/2) message to both produces the same result +// (only roots_[1..N/2-1] is touched, and those entries coincide). +// * Feeding messages of *different* sizes (size N/2 to FFT(N), size N to +// FFT(2N), even with matching prefix) engages additional butterfly +// stages in FFT(2N), so the outputs diverge. +// ============================================================================ + +TEST(FftDegreeSensitivity, ForwardFFTSameForNAnd2NWithSameSizeMsg) { + constexpr u64 degree = 128; + constexpr Size num_slots = degree / 2; + + std::mt19937_64 rng(0xcafe5678); + Message msg_n(num_slots), msg_2n(num_slots); + for (Size i = 0; i < num_slots; ++i) { + ComplexT val{static_cast(rng() % 1000), + static_cast(rng() % 1000)}; + msg_n[i] = val; + msg_2n[i] = val; + } + + utils::FFT fft_n(degree); + utils::FFT fft_2n(2 * degree); + + fft_n.forwardFFT(msg_n); + fft_2n.forwardFFT(msg_2n); + + for (Size i = 0; i < num_slots; ++i) { + EXPECT_NEAR(msg_n[i].real(), msg_2n[i].real(), 1e-10); + EXPECT_NEAR(msg_n[i].imag(), msg_2n[i].imag(), 1e-10); + } +} + +TEST(FftDegreeSensitivity, BackwardFFTSameForNAnd2NWithSameSizeMsg) { + constexpr u64 degree = 128; + constexpr Size num_slots = degree / 2; + + std::mt19937_64 rng(0xdead8765); + Message msg_n(num_slots), msg_2n(num_slots); + for (Size i = 0; i < num_slots; ++i) { + ComplexT val{static_cast(rng() % 1000), + static_cast(rng() % 1000)}; + msg_n[i] = val; + msg_2n[i] = val; + } + + utils::FFT fft_n(degree); + utils::FFT fft_2n(2 * degree); + + fft_n.backwardFFT(msg_n); + fft_2n.backwardFFT(msg_2n); + + for (Size i = 0; i < num_slots; ++i) { + EXPECT_NEAR(msg_n[i].real(), msg_2n[i].real(), 1e-10); + EXPECT_NEAR(msg_n[i].imag(), msg_2n[i].imag(), 1e-10); + } +} + +TEST(FftDegreeSensitivity, ForwardFFTDiffersWhenMsgSizesDiffer) { + constexpr u64 degree = 128; + constexpr Size num_slots_n = degree / 2; + constexpr Size num_slots_2n = degree; + + std::mt19937_64 rng(0xfeed1111); + Message msg_n(num_slots_n), msg_2n(num_slots_2n); + for (Size i = 0; i < num_slots_n; ++i) { + ComplexT val{static_cast(rng() % 1000), + static_cast(rng() % 1000)}; + msg_n[i] = val; + msg_2n[i] = val; + } + for (Size i = num_slots_n; i < num_slots_2n; ++i) { + msg_2n[i] = {static_cast(rng() % 1000), + static_cast(rng() % 1000)}; + } + + utils::FFT fft_n(degree); + utils::FFT fft_2n(2 * degree); + + fft_n.forwardFFT(msg_n); + fft_2n.forwardFFT(msg_2n); + + bool any_differ = false; + for (Size i = 0; i < num_slots_n && !any_differ; ++i) { + any_differ = std::abs(msg_n[i].real() - msg_2n[i].real()) > 1e-10 || + std::abs(msg_n[i].imag() - msg_2n[i].imag()) > 1e-10; + } + EXPECT_TRUE(any_differ) + << "FFT(N) on size-(N/2) msg and FFT(2N) on size-N msg with matching " + "prefix must produce different forwardFFT results"; +} + +TEST(FftDegreeSensitivity, BackwardFFTDiffersWhenMsgSizesDiffer) { + constexpr u64 degree = 128; + constexpr Size num_slots_n = degree / 2; + constexpr Size num_slots_2n = degree; + + std::mt19937_64 rng(0xbeef2222); + Message msg_n(num_slots_n), msg_2n(num_slots_2n); + for (Size i = 0; i < num_slots_n; ++i) { + ComplexT val{static_cast(rng() % 1000), + static_cast(rng() % 1000)}; + msg_n[i] = val; + msg_2n[i] = val; + } + for (Size i = num_slots_n; i < num_slots_2n; ++i) { + msg_2n[i] = {static_cast(rng() % 1000), + static_cast(rng() % 1000)}; + } + + utils::FFT fft_n(degree); + utils::FFT fft_2n(2 * degree); + + fft_n.backwardFFT(msg_n); + fft_2n.backwardFFT(msg_2n); + + bool any_differ = false; + for (Size i = 0; i < num_slots_n && !any_differ; ++i) { + any_differ = std::abs(msg_n[i].real() - msg_2n[i].real()) > 1e-10 || + std::abs(msg_n[i].imag() - msg_2n[i].imag()) > 1e-10; + } + EXPECT_TRUE(any_differ) + << "FFT(N) on size-(N/2) msg and FFT(2N) on size-N msg with matching " + "prefix must produce different backwardFFT results"; +} + +// ---------------------------------------------------------------------------- +// DIRECT vs MIN: roots should differ (algorithm sanity check) +// +// MIN selects the *minimum* primitive 2N-th root; DIRECT takes the first one +// found by the 2-adic search. For most primes these produce different values, +// confirming the two code paths are actually distinct. +// ---------------------------------------------------------------------------- + +// ============================================================================ +// NTTFactory dispatch +// +// Verifies that setNTTFactory()-registered backends are picked up by +// createNTT(), and that returning an empty shared_ptr from the custom factory +// falls through to the default makeNTT() path. The custom NTT here is a +// stub — it does not implement a real transform, just stamps a marker into +// the buffer so the test can confirm the stub was actually invoked. +// ============================================================================ + +namespace { + +class MarkerNTT : public utils::NTT_base { +public: + static constexpr u64 kForwardMarker = 0xDEADBEEFULL; + static constexpr u64 kBackwardMarker = 0xCAFEBABEULL; + + MarkerNTT() = default; + + void computeForward(u64 *op) const override { op[0] = kForwardMarker; } + void computeBackward(u64 *op) const override { op[0] = kBackwardMarker; } +}; + +// Restores the default (empty) factory on scope exit so a failing assertion +// does not leak custom-factory state into subsequent tests. +struct ResetNTTFactoryOnExit { + ~ResetNTTFactoryOnExit() { utils::setNTTFactory({}); } +}; + +} // namespace + +TEST(NTTFactoryDispatch, CustomFactoryIsInvoked) { + ResetNTTFactoryOnExit reset; + + bool factory_called = false; + utils::setNTTFactory( + [&](u64, u64, utils::NTTType, + utils::NTTRootType) -> std::shared_ptr> { + factory_called = true; + return std::make_shared(); + }); + + auto ntt = + utils::createNTT(64, 4295688193ULL, utils::NTTType::NEGACYCLIC); + EXPECT_TRUE(factory_called); + ASSERT_NE(ntt, nullptr); + EXPECT_NE(dynamic_cast(ntt.get()), nullptr); + + std::vector op(64, 0); + ntt->computeForward(op.data()); + EXPECT_EQ(op[0], MarkerNTT::kForwardMarker); + ntt->computeBackward(op.data()); + EXPECT_EQ(op[0], MarkerNTT::kBackwardMarker); +} + +TEST(NTTFactoryDispatch, EmptyReturnFallsThroughToDefault) { + ResetNTTFactoryOnExit reset; + + bool factory_called = false; + utils::setNTTFactory( + [&](u64, u64, utils::NTTType, + utils::NTTRootType) -> std::shared_ptr> { + factory_called = true; + return {}; + }); + + auto ntt = + utils::createNTT(64, 4295688193ULL, utils::NTTType::NEGACYCLIC); + EXPECT_TRUE(factory_called); + ASSERT_NE(ntt, nullptr); + EXPECT_NE(dynamic_cast *>(ntt.get()), nullptr); + EXPECT_EQ(dynamic_cast(ntt.get()), nullptr); +} + +TEST(NTTRootTypeAlgo, DirectAndMinUseDifferentPsiForTypicalPrime) { + // Build NTTs with both modes and check their first forward twiddle factor + // (psi_rev_[1]) differs. We do this indirectly: apply forward NTT to + // the unit vector e_1 = [0,1,0,...,0] — the result is the list of psi + // powers in bit-reversed order, so result[0] and result[1] expose psi. + constexpr u64 N = kSmallDegree; + constexpr u64 p = kSmallPrime; + + auto getPsiFromNTT = [&](utils::NTTRootType rt) { + utils::ScopedNTTRootType guard{rt}; + utils::NTT ntt{N, p}; + vector e1(N, 0); + e1[1] = 1; + ntt.computeForward(e1.data()); + // e1 after forward NTT: e1[i] = psi^(bit_reverse(i)) in some ordering. + // The element at index 1 (bit-reversed: N/2) gives psi^(N/2). + // Regardless, we just need to observe that the two NTTs yield + // *different* output vectors, which is sufficient to confirm the paths + // diverge. + return e1; + }; + + auto min_out = getPsiFromNTT(utils::NTTRootType::MIN); + auto direct_out = getPsiFromNTT(utils::NTTRootType::DIRECT); + + // For this well-known prime/degree pair the two roots are different. + // If they happen to coincide (astronomically unlikely), the test is moot + // but not a correctness failure — skip rather than fail. + if (min_out == direct_out) + GTEST_SKIP() << "MIN and DIRECT happened to select the same psi"; + + EXPECT_NE(min_out, direct_out) + << "MIN min-root selection and DIRECT 2-adic search should " + "yield different primitive roots for this prime"; +} diff --git a/test/TestBase.hpp b/test/TestBase.hpp index 11a97a4..85e3af4 100644 --- a/test/TestBase.hpp +++ b/test/TestBase.hpp @@ -40,6 +40,10 @@ using namespace deb; #define DEB_TEST_EXPECT(statement) EXPECT_DEATH(statement, ".*") #endif +#define X(preset) deb::PRESET_##preset, +inline std::vector Presets = {PRESET_LIST}; +#undef X + using MSGS = std::vector; using FMSGS = std::vector; using COEFFS = std::vector; @@ -47,10 +51,10 @@ using FCOEFFS = std::vector; class DebTestBase : public ::testing::TestWithParam { public: - const Preset preset{GetParam()}; - const Size num_slots{get_num_slots(preset)}; - const Size degree{get_degree(preset)}; - const Size num_secret{get_num_secret(preset)}; + Preset preset{GetParam()}; + Size num_slots{get_num_slots(preset)}; + Size degree{get_degree(preset)}; + Size num_secret{get_num_secret(preset)}; Encryptor encryptor{preset}; Decryptor decryptor{preset}; @@ -184,6 +188,102 @@ class DebTestBase : public ::testing::TestWithParam { } } } + // Scales degree complex Message/FMessage values (both real and imag) for + // unscaled presets. For real-HEAAN where the message has complex slots. + template T scale_complex_message(T &msg, uint32_t level) { + const double scale = get_scale_factors(preset)[level]; + if (scale == 0.0) { + const double sc = + std::pow(2.0, utils::bitWidth(get_primes(preset)[0]) - 4); + T scale_msg = gen_empty_real_message(); + for (Size i = 0; i < num_secret; ++i) { + for (Size j = 0; j < degree; ++j) { + if constexpr (std::is_same_v) { + scale_msg[i][j].real( + static_cast(msg[i][j].real() * sc)); + scale_msg[i][j].imag( + static_cast(msg[i][j].imag() * sc)); + } else { + scale_msg[i][j].real(msg[i][j].real() * sc); + scale_msg[i][j].imag(msg[i][j].imag() * sc); + } + } + } + return scale_msg; + } + return msg; + } + // Compares all degree complex values (real + imag). + // Used for real-HEAAN Message tests where the decoded output is a + // Galois-Hermitian symmetric complex Message. + template + void compare_heaan_msg(T &msg1, T &msg2, double tol) const { + for (Size i = 0; i < num_secret; ++i) { + for (Size j = 0; j < degree; ++j) { + ASSERT_NEAR(static_cast(msg1[i][j].real()), + static_cast(msg2[i][j].real()), tol); + ASSERT_NEAR(static_cast(msg1[i][j].imag()), + static_cast(msg2[i][j].imag()), tol); + } + } + } + template T gen_random_real_message() { + T msg; + for (Size i = 0; i < num_secret; ++i) { + if constexpr (std::is_same_v) { + FMessage m(degree); + for (Size j = 0; j < degree; ++j) { + m[j].real(static_cast(dist(gen))); + m[j].imag(0.0f); + } + msg.emplace_back(std::move(m)); + } else if constexpr (std::is_same_v) { + Message m(degree); + for (Size j = 0; j < degree; ++j) { + m[j].real(dist(gen)); + m[j].imag(0.0); + } + msg.emplace_back(std::move(m)); + } + } + return msg; + } + template T gen_empty_real_message() { + T msg; + for (Size i = 0; i < num_secret; ++i) { + msg.emplace_back(degree); + } + return msg; + } + template T scale_real_message(T &msg, uint32_t level) { + const double scale = get_scale_factors(preset)[level]; + if (scale == 0.0) { + const double sc = + std::pow(2.0, utils::bitWidth(get_primes(preset)[0]) - 4); + T scale_msg = gen_empty_real_message(); + for (Size i = 0; i < num_secret; ++i) { + for (Size j = 0; j < degree; ++j) { + if constexpr (std::is_same_v) { + scale_msg[i][j].real( + static_cast(msg[i][j].real() * sc)); + } else { + scale_msg[i][j].real(msg[i][j].real() * sc); + } + } + } + return scale_msg; + } + return msg; + } + template + void compare_real_msg(T &msg1, T &msg2, double tol) const { + for (Size i = 0; i < num_secret; ++i) { + for (Size j = 0; j < degree; ++j) { + ASSERT_NEAR(static_cast(msg1[i][j].real()), + static_cast(msg2[i][j].real()), tol); + } + } + } template void compareArray(const T *arr1, const T *arr2, const Size size) { for (Size i = 0; i < size; ++i) { diff --git a/test/U32-test.cpp b/test/U32-test.cpp index 99366b3..98b44a7 100644 --- a/test/U32-test.cpp +++ b/test/U32-test.cpp @@ -31,6 +31,7 @@ */ #include "CKKSTypes.hpp" +#include "KeyGenerator.hpp" #include "TestBase.hpp" #include "utils/Basic.hpp" #include "utils/ModArith.hpp" @@ -130,7 +131,7 @@ INSTANTIATE_TEST_SUITE_P(ShoupPrimes, ComputeShoupTest, // NTT // ========================================================================= -class NttU32Test : public ::testing::TestWithParam> { +class NTTU32Test : public ::testing::TestWithParam> { public: u64 degree{std::get<0>(GetParam())}; u64 prime{std::get<1>(GetParam())}; @@ -148,7 +149,7 @@ class NttU32Test : public ::testing::TestWithParam> { } }; -TEST_P(NttU32Test, RoundTrip) { +TEST_P(NTTU32Test, RoundTrip) { utils::NTT ntt{degree, prime}; auto orig = randomVector(); @@ -160,7 +161,7 @@ TEST_P(NttU32Test, RoundTrip) { EXPECT_EQ(v, orig); } -TEST_P(NttU32Test, OneZeroVector) { +TEST_P(NTTU32Test, OneZeroVector) { // NTT([1, 0, 0, ...]) == [1, 1, 1, ...] utils::NTT ntt{degree, prime}; @@ -172,7 +173,7 @@ TEST_P(NttU32Test, OneZeroVector) { EXPECT_EQ(op, expected); } -TEST_P(NttU32Test, ConvolutionViaPointwiseMul) { +TEST_P(NTTU32Test, ConvolutionViaPointwiseMul) { // For negacyclic NTT: NTT(a) * NTT(b) == NTT(a * b mod (X^N + 1)). // Verify that a simple polynomial product through NTT is consistent. utils::NTT ntt{degree, prime}; @@ -212,7 +213,7 @@ TEST_P(NttU32Test, ConvolutionViaPointwiseMul) { } #ifdef DEB_U64 -TEST_P(NttU32Test, CompareU64) { +TEST_P(NTTU32Test, CompareU64) { utils::NTT ntt32{degree, prime}; utils::NTT ntt64{degree, prime}; @@ -236,13 +237,13 @@ TEST_P(NttU32Test, CompareU64) { #endif INSTANTIATE_TEST_SUITE_P( - U32NttParams, NttU32Test, + U32NTTParams, NTTU32Test, testing::Values(std::tuple{DEGREE, PRIME_A}, std::tuple{DEGREE, PRIME_B}, std::tuple{DEGREE, PRIME_C}, std::tuple{2 * DEGREE, PRIME_A}, - std::tuple{2 * DEGREE, 2147377153}, - std::tuple{2 * DEGREE, 2147352577})); + std::tuple{2 * DEGREE, 1073692673}, + std::tuple{2 * DEGREE, 1073668097})); // ========================================================================= // ModArith<1, u32> @@ -450,9 +451,9 @@ TEST_F(PolyUnitU32Test, SetPrime) { TEST_F(PolyUnitU32Test, SetNTTFlag) { PolyUnitT pu(prime, deg); EXPECT_FALSE(pu.isNTT()); - pu.setNTT(true); + pu.setNTT(utils::NTTType::NEGACYCLIC); EXPECT_TRUE(pu.isNTT()); - pu.setNTT(false); + pu.setNTT(utils::NTTType::NONNTT); EXPECT_FALSE(pu.isNTT()); } @@ -461,7 +462,7 @@ TEST_F(PolyUnitU32Test, SetNTTFlag) { // ========================================================================= #ifdef DEB_U64 -class NttU32vsU64Test : public ::testing::Test { +class NTTU32vsU64Test : public ::testing::Test { public: static constexpr u64 prime = PRIME_A; static constexpr Size degree = DEGREE; @@ -477,7 +478,7 @@ class NttU32vsU64Test : public ::testing::Test { } }; -TEST_F(NttU32vsU64Test, ForwardNTTMatchesU64) { +TEST_F(NTTU32vsU64Test, ForwardNTTMatchesU64) { utils::NTT ntt32{degree, prime}; utils::NTT ntt64{degree, prime}; @@ -493,7 +494,7 @@ TEST_F(NttU32vsU64Test, ForwardNTTMatchesU64) { } } -TEST_F(NttU32vsU64Test, BackwardNTTMatchesU64) { +TEST_F(NTTU32vsU64Test, BackwardNTTMatchesU64) { utils::NTT ntt32{degree, prime}; utils::NTT ntt64{degree, prime}; @@ -677,3 +678,124 @@ const std::vector all_presets32 = {PRESET_LIST_U32 }; INSTANTIATE_TEST_SUITE_P(Presets, Endecrypt32Test, testing::ValuesIn(all_presets32)); + +// ========================================================================= +// KeyGenerator32: switching-key generation for u32 presets. +// +// Mirrors KeyGen-test.cpp (u64) but exercises KeyGeneratorT +// against the u32-dedicated preset(s). Verifies that each genXxxKey API: +// - runs without exceptions (functional + Inplace forms), +// - produces switching keys of the expected ax/bx dimensions for the preset. +// ========================================================================= + +class KeyGen32Test : public ::testing::TestWithParam { +public: + const Preset preset{GetParam()}; + const Size num_slots{get_num_slots(preset)}; + const Size num_secret{get_num_secret(preset)}; + const Size degree{get_degree(preset)}; + const Size gadget_rank{get_gadget_rank(preset)}; + const Size num_p{get_num_p(preset)}; + + KeyGenerator32 keygen{preset}; + SecretKey32 sk{SecretKeyGenerator32::GenSecretKey(preset)}; + + std::mt19937 gen{std::random_device{}()}; + std::uniform_int_distribution dist_u64{0, UINT64_MAX}; +}; + +TEST_P(KeyGen32Test, GenEncryptionKey) { + SwitchKey32 enckey(preset, SwitchKeyKind::SWK_ENC); + ASSERT_NO_THROW(enckey = keygen.genEncKey(sk)); + ASSERT_NO_THROW(keygen.genEncKeyInplace(enckey, sk)); + + ASSERT_EQ(enckey.axSize(), 1); + ASSERT_EQ(enckey.bxSize(), num_secret); + ASSERT_EQ(enckey.ax().size(), num_p); + ASSERT_EQ(enckey.bx().size(), num_p); +} + +TEST_P(KeyGen32Test, GenMultiplicationKey) { + SwitchKey32 mulkey(preset, SwitchKeyKind::SWK_MULT); + ASSERT_NO_THROW(mulkey = keygen.genMultKey(sk)); + ASSERT_NO_THROW(keygen.genMultKeyInplace(mulkey, sk)); + + ASSERT_EQ(mulkey.axSize(), gadget_rank); + ASSERT_EQ(mulkey.bxSize(), gadget_rank * num_secret); + ASSERT_EQ(mulkey.ax().size(), num_p); + ASSERT_EQ(mulkey.bx().size(), num_p); +} + +TEST_P(KeyGen32Test, GenConjugationKey) { + SwitchKey32 conjkey(preset, SwitchKeyKind::SWK_CONJ); + ASSERT_NO_THROW(conjkey = keygen.genConjKey(sk)); + ASSERT_NO_THROW(keygen.genConjKeyInplace(conjkey, sk)); + + ASSERT_EQ(conjkey.axSize(), gadget_rank); + ASSERT_EQ(conjkey.bxSize(), gadget_rank * num_secret); + ASSERT_EQ(conjkey.ax().size(), num_p); + ASSERT_EQ(conjkey.bx().size(), num_p); +} + +TEST_P(KeyGen32Test, GenRotationKeys) { + const Size rot = dist_u64(gen) % (num_slots - 1) + 1; + + SwitchKey32 left_rotkey(preset, SwitchKeyKind::SWK_ROT, rot); + ASSERT_NO_THROW(left_rotkey = keygen.genLeftRotKey(rot, sk)); + ASSERT_NO_THROW(keygen.genLeftRotKeyInplace(rot, left_rotkey, sk)); + + SwitchKey32 right_rotkey(preset, SwitchKeyKind::SWK_ROT, num_slots - rot); + ASSERT_NO_THROW(right_rotkey = keygen.genRightRotKey(num_slots - rot, sk)); + ASSERT_NO_THROW( + keygen.genRightRotKeyInplace(num_slots - rot, right_rotkey, sk)); + + ASSERT_EQ(left_rotkey.axSize(), gadget_rank); + ASSERT_EQ(left_rotkey.bxSize(), gadget_rank * num_secret); + ASSERT_EQ(left_rotkey.ax().size(), num_p); + ASSERT_EQ(left_rotkey.bx().size(), num_p); + + ASSERT_EQ(right_rotkey.axSize(), gadget_rank); + ASSERT_EQ(right_rotkey.bxSize(), gadget_rank * num_secret); + ASSERT_EQ(right_rotkey.ax().size(), num_p); + ASSERT_EQ(right_rotkey.bx().size(), num_p); +} + +TEST_P(KeyGen32Test, GenAutomorphismKey) { + const Size sig = dist_u64(gen) % (degree - 1) + 1; + SwitchKey32 autokey(preset, SwitchKeyKind::SWK_AUTO); + ASSERT_NO_THROW(autokey = keygen.genAutoKey(sig, sk)); + ASSERT_NO_THROW(keygen.genAutoKeyInplace(sig, autokey, sk)); + + ASSERT_EQ(autokey.axSize(), gadget_rank); + ASSERT_EQ(autokey.bxSize(), gadget_rank * num_secret); + ASSERT_EQ(autokey.ax().size(), num_p); + ASSERT_EQ(autokey.bx().size(), num_p); +} + +TEST_P(KeyGen32Test, GenModPackKey) { + if (num_secret != 1) { + GTEST_SKIP() << "MODPACK key generation is only for single secret."; + } + std::vector modkey; + ASSERT_NO_THROW(modkey = keygen.genModPackKeyBundle(sk, sk)); + ASSERT_NO_THROW(keygen.genModPackKeyBundleInplace(sk, sk, modkey)); +} + +TEST_P(KeyGen32Test, GenModPackKeySelf) { + if (num_secret != 1) { + GTEST_SKIP() + << "MODPACK_SELF key generation is only for single secret."; + } + const Size pad_rank = 1U << (dist_u64(gen) % (get_log_degree(preset) / 2)); + SwitchKey32 modkey(preset, SwitchKeyKind::SWK_MODPACK_SELF); + ASSERT_NO_THROW(modkey = keygen.genModPackKeyBundle(pad_rank, sk)); + ASSERT_NO_THROW(keygen.genModPackKeyBundleInplace(pad_rank, modkey, sk)); + + ASSERT_EQ(modkey.axSize(), pad_rank); + ASSERT_EQ(modkey.bxSize(), pad_rank * num_secret); + ASSERT_EQ(modkey.ax().size(), num_p); + ASSERT_EQ(modkey.bx().size(), num_p); +} + +INSTANTIATE_TEST_SUITE_P(Presets, KeyGen32Test, + testing::ValuesIn(all_presets32)); From c3ec17478d5e5ff9d5b3a9adbdb392bc746a0971 Mon Sep 17 00:00:00 2001 From: SeungjunLee Date: Wed, 27 May 2026 07:34:52 +0000 Subject: [PATCH 2/4] remove bundling.cmake --- cmake/bundling.cmake | 65 -------------------------------------------- 1 file changed, 65 deletions(-) delete mode 100644 cmake/bundling.cmake diff --git a/cmake/bundling.cmake b/cmake/bundling.cmake deleted file mode 100644 index 29eba00..0000000 --- a/cmake/bundling.cmake +++ /dev/null @@ -1,65 +0,0 @@ -# ~~~ -# Copyright 2026 CryptoLab, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ~~~ - -# Merge the object files of `dependency` target with that of `target`. Both -# should be static or object libraries; otherwise this is no-op. -function(merge_archive_if_static target dependency) - # Check if the target has object files - list(APPEND TYPES_HAVING_OBJECTS "STATIC_LIBRARY" "OBJECT_LIBRARY") - get_target_property(IS_STATIC ${target} TYPE) - if(NOT IS_STATIC IN_LIST TYPES_HAVING_OBJECTS) - return() - endif() - - # Check if the dependency is a target and has object files - if(NOT TARGET ${dependency}) - return() - endif() - get_target_property(IS_STATIC ${dependency} TYPE) - if(NOT IS_STATIC IN_LIST TYPES_HAVING_OBJECTS) - return() - endif() - - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") - add_custom_command( - TARGET ${target} - POST_BUILD - COMMAND rm -rf ${target}_objs && mkdir ${target}_objs - COMMAND rm -rf ${dependency}_objs && mkdir ${dependency}_objs - COMMAND ${CMAKE_COMMAND} -E chdir ${target}_objs ${CMAKE_AR} -x - $ - COMMAND ${CMAKE_COMMAND} -E chdir ${dependency}_objs ${CMAKE_AR} -x - $ - COMMAND ar -qcs $ ${target}_objs/*.o - ${dependency}_objs/*.o - COMMAND rm -rf ${target}_objs ${dependency}_objs - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) # DEPENDS ${target} - # ${dependency}) - elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - add_custom_command( - TARGET ${target} - POST_BUILD - COMMAND lib.exe /OUT:$ $ - $ - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) # DEPENDS ${target} - # ${dependency}) - else() - message( - WARNING - "Failed merging ${target} target with ${dependency}: unsupported compiler" - ) - endif() -endfunction() From cf3c85134fe20075f774a3cdad747c36bcfd45ee Mon Sep 17 00:00:00 2001 From: SeungjunLee Date: Thu, 28 May 2026 00:52:41 +0000 Subject: [PATCH 3/4] remove unnecessary commented line --- CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index aad7181..8d83276 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,7 +31,6 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) -# set(CMAKE_CXX_VISIBILITY_PRESET "default") set(CMAKE_VISIBILITY_INLINES_HIDDEN OFF) set(CMAKE_INSTALL_PREFIX From eefd5822ad49446d42b714b896e7c70ccef24c73 Mon Sep 17 00:00:00 2001 From: SeungjunLee Date: Thu, 28 May 2026 01:50:32 +0000 Subject: [PATCH 4/4] remove std namespace using-directive --- CMakeLists.txt | 2 +- examples/EnDecryption-MultiSecret.cpp | 1 - examples/EnDecryption-Real.cpp | 1 - examples/EnDecryption.cpp | 1 - examples/KeyGeneration.cpp | 1 - examples/SeedOnlySecretKey.cpp | 5 ++- examples/Serialization.cpp | 5 ++- test/NTT-test.cpp | 51 +++++++++++++-------------- 8 files changed, 30 insertions(+), 37 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d83276..e57f3b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,7 +17,7 @@ cmake_minimum_required(VERSION 3.21) project( deb - VERSION 0.3.1 + VERSION 0.4.0 LANGUAGES CXX DESCRIPTION "CryptoLab's official cryptosystem library for FHE.") diff --git a/examples/EnDecryption-MultiSecret.cpp b/examples/EnDecryption-MultiSecret.cpp index 64154a3..5b541ff 100644 --- a/examples/EnDecryption-MultiSecret.cpp +++ b/examples/EnDecryption-MultiSecret.cpp @@ -16,7 +16,6 @@ #include "ExampleUtils.hpp" -using namespace std; using namespace deb; int main() { diff --git a/examples/EnDecryption-Real.cpp b/examples/EnDecryption-Real.cpp index dd553a1..6757c34 100644 --- a/examples/EnDecryption-Real.cpp +++ b/examples/EnDecryption-Real.cpp @@ -17,7 +17,6 @@ #include "ExampleUtils.hpp" #include -using namespace std; using namespace deb; namespace { diff --git a/examples/EnDecryption.cpp b/examples/EnDecryption.cpp index 54a6542..ec38284 100644 --- a/examples/EnDecryption.cpp +++ b/examples/EnDecryption.cpp @@ -17,7 +17,6 @@ #include "ExampleUtils.hpp" #include -using namespace std; using namespace deb; int main() { diff --git a/examples/KeyGeneration.cpp b/examples/KeyGeneration.cpp index b228030..85e50e7 100644 --- a/examples/KeyGeneration.cpp +++ b/examples/KeyGeneration.cpp @@ -16,7 +16,6 @@ #include "ExampleUtils.hpp" -using namespace std; using namespace deb; int main() { diff --git a/examples/SeedOnlySecretKey.cpp b/examples/SeedOnlySecretKey.cpp index 56b4a43..87802cf 100644 --- a/examples/SeedOnlySecretKey.cpp +++ b/examples/SeedOnlySecretKey.cpp @@ -19,7 +19,6 @@ #include "Serialize.hpp" #endif -using namespace std; using namespace deb; int main() { @@ -36,12 +35,12 @@ int main() { #ifdef DEB_SERIALIZE // Serialize seed only secret key - ostringstream os; + std::ostringstream os; serializeToStream(seed_only_sk, os); std::cout << "Serialized secret key size (seed only): " << os.str().size() << " bytes" << std::endl; // Serialize coeff only secret key - os = ostringstream(); // Clear the stream + os = std::ostringstream(); // Clear the stream serializeToStream(coeff_only_sk, os); std::cout << "Serialized secret key size (coeff only): " << os.str().size() << " bytes" << std::endl; #endif diff --git a/examples/Serialization.cpp b/examples/Serialization.cpp index bfc62a1..ca15c10 100644 --- a/examples/Serialization.cpp +++ b/examples/Serialization.cpp @@ -20,7 +20,6 @@ #include #include -using namespace std; using namespace deb; int main() { @@ -51,14 +50,14 @@ int main() { Ciphertext cipher(preset); encryptor.encrypt(msg, enckey, cipher); - ofstream of(tmp_dir + "serialize_example1.bin", ios::binary); + std::ofstream of(tmp_dir + "serialize_example1.bin", std::ios::binary); serializeToStream(msg, of); serializeToStream(sk, of); serializeToStream(enckey, of); serializeToStream(cipher, of); of.close(); - ifstream inf(tmp_dir + "serialize_example1.bin", ios::binary); + std::ifstream inf(tmp_dir + "serialize_example1.bin", std::ios::binary); Message msg2(preset); SecretKey sk2(preset); SwitchKey enckey2(preset, SWK_ENC); diff --git a/test/NTT-test.cpp b/test/NTT-test.cpp index 0f6f8c3..c764900 100644 --- a/test/NTT-test.cpp +++ b/test/NTT-test.cpp @@ -24,12 +24,11 @@ #include using namespace deb; -using namespace std; class NTTTest : public ::testing::TestWithParam> { public: - const u64 degree{get<0>(GetParam())}; - const u64 prime{get<1>(GetParam())}; + const u64 degree{std::get<0>(GetParam())}; + const u64 prime{std::get<1>(GetParam())}; std::random_device rd; std::mt19937 gen{rd()}; @@ -226,13 +225,13 @@ INSTANTIATE_TEST_SUITE_P(CustomModeTiny, NTTTest, class CyclicNTTTest : public ::testing::TestWithParam> { public: - const u64 degree{get<0>(GetParam())}; - const u64 prime{get<1>(GetParam())}; + const u64 degree{std::get<0>(GetParam())}; + const u64 prime{std::get<1>(GetParam())}; std::mt19937_64 gen{std::random_device{}()}; - vector random_vec() { - vector v(degree); - uniform_int_distribution dist(0, prime - 1); + std::vector random_vec() { + std::vector v(degree); + std::uniform_int_distribution dist(0, prime - 1); for (auto &x : v) x = dist(gen); return v; @@ -265,10 +264,10 @@ TEST_P(CyclicNTTTest, ForwardOfConstantOne) { // ends up evaluating f(x) = 1 at every N-th root of unity — every bin is 1. utils::NTT_C ntt{degree, prime}; - vector op(degree, 0); + std::vector op(degree, 0); op[0] = 1; ntt.computeForward(op.data()); - EXPECT_EQ(op, vector(degree, 1)); + EXPECT_EQ(op, std::vector(degree, 1)); } TEST_P(CyclicNTTTest, DirectRoundTrip) { @@ -308,10 +307,10 @@ constexpr u64 kSmallDegree = 64; constexpr u64 kSmallPrime = 4295688193ULL; // Naive negacyclic convolution: c = a * b mod (X^N + 1) mod p. -vector negacyclicConv(const vector &a, const vector &b, - u64 prime) { +std::vector negacyclicConv(const std::vector &a, + const std::vector &b, u64 prime) { const u64 N = a.size(); - vector c(N, 0); + std::vector c(N, 0); for (u64 i = 0; i < N; i++) { for (u64 j = 0; j < N; j++) { u64 prod = utils::mulModSimple(a[i], b[j], prime); @@ -344,7 +343,7 @@ TEST(CyclicNTTPolyMul, AllModesAgreeOnPointwiseProduct) { constexpr u64 p = kSmallPrime; std::mt19937_64 rng(0xcafe1234); - vector a(N), b(N); + std::vector a(N), b(N); for (auto &x : a) x = rng() % p; for (auto &x : b) @@ -352,7 +351,7 @@ TEST(CyclicNTTPolyMul, AllModesAgreeOnPointwiseProduct) { auto nttMul = [&](utils::NTTRootType rt) { utils::NTT_C ntt{N, p, rt}; - vector fa(a), fb(b), fc(N); + std::vector fa(a), fb(b), fc(N); ntt.computeForward(fa.data()); ntt.computeForward(fb.data()); for (u64 i = 0; i < N; i++) @@ -382,7 +381,7 @@ TEST(CyclicNTTPolyMul, CyclicDiffersFromNegacyclic) { constexpr u64 p = kSmallPrime; std::mt19937_64 rng(0xdead5678); - vector a(N), b(N); + std::vector a(N), b(N); for (auto &x : a) x = rng() % p; for (auto &x : b) @@ -390,7 +389,7 @@ TEST(CyclicNTTPolyMul, CyclicDiffersFromNegacyclic) { auto cycMul = [&]() { utils::NTT_C ntt{N, p}; - vector fa(a), fb(b), fc(N); + std::vector fa(a), fb(b), fc(N); ntt.computeForward(fa.data()); ntt.computeForward(fb.data()); for (u64 i = 0; i < N; i++) @@ -400,7 +399,7 @@ TEST(CyclicNTTPolyMul, CyclicDiffersFromNegacyclic) { }; auto negMul = [&]() { utils::NTT ntt{N, p}; - vector fa(a), fb(b), fc(N); + std::vector fa(a), fb(b), fc(N); ntt.computeForward(fa.data()); ntt.computeForward(fb.data()); for (u64 i = 0; i < N; i++) @@ -486,7 +485,7 @@ TEST(NTTPolyMul, AllModesAgreeOnNegacyclicConvolution) { constexpr u64 p = kSmallPrime; std::mt19937_64 rng(0xdeb1cafe); - vector a(N), b(N); + std::vector a(N), b(N); for (auto &x : a) x = rng() % p; for (auto &x : b) @@ -498,7 +497,7 @@ TEST(NTTPolyMul, AllModesAgreeOnNegacyclicConvolution) { auto nttConv = [&](utils::NTTRootType rt) { utils::ScopedNTTRootType guard{rt}; utils::NTT ntt{N, p}; - vector fa(a), fb(b), fc(N); + std::vector fa(a), fb(b), fc(N); ntt.computeForward(fa.data()); ntt.computeForward(fb.data()); for (u64 i = 0; i < N; i++) @@ -524,7 +523,7 @@ TEST(NTTPolyMul, AllModesAgreeOn40bitPrime) { 2199020634113ULL; // 40-bit NTT prime (degree 8192-friendly) std::mt19937_64 rng(0xcafe0123); - vector a(N), b(N); + std::vector a(N), b(N); for (auto &x : a) x = rng() % p; for (auto &x : b) @@ -535,7 +534,7 @@ TEST(NTTPolyMul, AllModesAgreeOn40bitPrime) { auto nttConv = [&](utils::NTTRootType rt) { utils::ScopedNTTRootType guard{rt}; utils::NTT ntt{N, p}; - vector fa(a), fb(b), fc(N); + std::vector fa(a), fb(b), fc(N); ntt.computeForward(fa.data()); ntt.computeForward(fb.data()); for (u64 i = 0; i < N; i++) @@ -560,7 +559,7 @@ TEST(NTTReregistration, RoundtripAfterPsiOverwrite) { constexpr u64 p = kSmallPrime; // Collect two distinct valid psi values using different bases. - vector valid_psi; + std::vector valid_psi; for (u64 base : {u64(3), u64(5), u64(7), u64(11), u64(13), u64(17)}) { u64 candidate = utils::powModSimple(base, (p - 1) / (2 * N), p); if (candidate != 1 && utils::powModSimple(candidate, N, p) != 1) { @@ -590,10 +589,10 @@ TEST(NTTReregistration, RoundtripAfterPsiOverwrite) { utils::NTT ntt{N, p}; std::mt19937_64 rng(0xbeef); - vector v(N); + std::vector v(N); for (auto &x : v) x = rng() % p; - vector result(v); + std::vector result(v); ntt.computeForward(result.data()); ntt.computeBackward(result.data()); @@ -825,7 +824,7 @@ TEST(NTTRootTypeAlgo, DirectAndMinUseDifferentPsiForTypicalPrime) { auto getPsiFromNTT = [&](utils::NTTRootType rt) { utils::ScopedNTTRootType guard{rt}; utils::NTT ntt{N, p}; - vector e1(N, 0); + std::vector e1(N, 0); e1[1] = 1; ntt.computeForward(e1.data()); // e1 after forward NTT: e1[i] = psi^(bit_reverse(i)) in some ordering.