Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ qtl_enrichment_rcpp <- function(r_gwas_pip, r_qtl_susie_fit, pi_gwas, pi_qtl, Im
.Call(`_pecotmr_qtl_enrichment_rcpp`, r_gwas_pip, r_qtl_susie_fit, pi_gwas, pi_qtl, ImpN, shrinkage_lambda, double_shrinkage, bessel_correction, num_threads)
}

sdpr_rcpp <- function(bhat_r, LD, n, per_variant_sample_size, array, a, c, M, a0k, b0k, iter, burn, thin, n_threads, opt_llk, verbose, seed) {
.Call(`_pecotmr_sdpr_rcpp`, bhat_r, LD, n, per_variant_sample_size, array, a, c, M, a0k, b0k, iter, burn, thin, n_threads, opt_llk, verbose, seed)
sdpr_rcpp <- function(bhat_r, LD, n, per_variant_sample_size, array, a, c, M, a0k, b0k, iter, burn, thin, n_threads, opt_llk, verbose, seed, random_init) {
.Call(`_pecotmr_sdpr_rcpp`, bhat_r, LD, n, per_variant_sample_size, array, a, c, M, a0k, b0k, iter, burn, thin, n_threads, opt_llk, verbose, seed, random_init)
}
14 changes: 12 additions & 2 deletions R/regularized_regression.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,18 @@ prs_cs_weights <- function(stat, LD, ...) {
#' @param iter Number of iterations for MCMC. Default is 1000.
#' @param burn Number of burn-in iterations for MCMC. Default is 200.
#' @param thin Thinning interval for MCMC. Default is 5.
#' @param init Initialization mode for latent cluster assignments. Use `"random"`
#' to initialize SNPs uniformly across SDPR mixture components, matching the
#' original SDPR implementation, or `"null"` to start all SNPs in the null
#' cluster. Default is `"random"`.
#' @param n_threads Number of threads to use. Default is 1.
#' @param opt_llk Which likelihood to evaluate. 1 for equation 6 (slightly shrink the correlation of SNPs)
#' and 2 for equation 5 (SNPs genotyped on different arrays in a separate cohort).
#' Default is 1.
#' @param verbose Whether to print verbose output. Default is \code{TRUE}.
#' @param seed Optional unsigned integer seed for the C++ SDPR sampler. When
#' \code{NULL}, the sampler uses \code{std::random_device} and is not
#' reproducible.
#'
#' @return A list containing the estimated effect sizes (beta) and heritability (h2).
#' @examples
Expand Down Expand Up @@ -203,8 +210,10 @@ prs_cs_weights <- function(stat, LD, ...) {
#'
#' @export
sdpr <- function(bhat, LD, n, per_variant_sample_size = NULL, array = NULL, a = 0.1, c = 1.0, M = 1000,
a0k = 0.5, b0k = 0.5, iter = 1000, burn = 200, thin = 5, n_threads = 1,
a0k = 0.5, b0k = 0.5, iter = 1000, burn = 200, thin = 5, init = c("random", "null"), n_threads = 1,
opt_llk = 1, verbose = TRUE, seed = NULL) {
init <- match.arg(init)

# Check if the sum of the rows in LD list is the same as length of bhat
if (sum(sapply(LD, nrow)) != length(bhat)) {
stop("The sum of the rows in LD list must be the same as the length of bhat.")
Expand Down Expand Up @@ -236,7 +245,8 @@ sdpr <- function(bhat, LD, n, per_variant_sample_size = NULL, array = NULL, a =
result <- sdpr_rcpp(
bhat, LD, as.integer(n), per_variant_sample_size, array, a, c, as.integer(M),
a0k, b0k, as.integer(iter), as.integer(burn), as.integer(thin),
as.integer(n_threads), as.integer(opt_llk), verbose, seed
as.integer(n_threads), as.integer(opt_llk), verbose, seed,
identical(init, "random")
)

return(result)
Expand Down
10 changes: 10 additions & 0 deletions man/sdpr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ extern "C" SEXP _pecotmr_qtl_enrichment_rcpp(SEXP r_gwas_pip, SEXP r_qtl_susie_f
END_CPP11
}
// sdpr.cpp
cpp11::writable::list sdpr_rcpp(const doubles& bhat_r, const list& LD, int n, sexp per_variant_sample_size, sexp array, double a, double c, int M, double a0k, double b0k, int iter, int burn, int thin, int n_threads, int opt_llk, bool verbose, sexp seed);
extern "C" SEXP _pecotmr_sdpr_rcpp(SEXP bhat_r, SEXP LD, SEXP n, SEXP per_variant_sample_size, SEXP array, SEXP a, SEXP c, SEXP M, SEXP a0k, SEXP b0k, SEXP iter, SEXP burn, SEXP thin, SEXP n_threads, SEXP opt_llk, SEXP verbose, SEXP seed) {
cpp11::writable::list sdpr_rcpp(const doubles& bhat_r, const list& LD, int n, sexp per_variant_sample_size, sexp array, double a, double c, int M, double a0k, double b0k, int iter, int burn, int thin, int n_threads, int opt_llk, bool verbose, sexp seed, bool random_init);
extern "C" SEXP _pecotmr_sdpr_rcpp(SEXP bhat_r, SEXP LD, SEXP n, SEXP per_variant_sample_size, SEXP array, SEXP a, SEXP c, SEXP M, SEXP a0k, SEXP b0k, SEXP iter, SEXP burn, SEXP thin, SEXP n_threads, SEXP opt_llk, SEXP verbose, SEXP seed, SEXP random_init) {
BEGIN_CPP11
return cpp11::as_sexp(sdpr_rcpp(cpp11::as_cpp<cpp11::decay_t<const doubles&>>(bhat_r), cpp11::as_cpp<cpp11::decay_t<const list&>>(LD), cpp11::as_cpp<cpp11::decay_t<int>>(n), cpp11::as_cpp<cpp11::decay_t<sexp>>(per_variant_sample_size), cpp11::as_cpp<cpp11::decay_t<sexp>>(array), cpp11::as_cpp<cpp11::decay_t<double>>(a), cpp11::as_cpp<cpp11::decay_t<double>>(c), cpp11::as_cpp<cpp11::decay_t<int>>(M), cpp11::as_cpp<cpp11::decay_t<double>>(a0k), cpp11::as_cpp<cpp11::decay_t<double>>(b0k), cpp11::as_cpp<cpp11::decay_t<int>>(iter), cpp11::as_cpp<cpp11::decay_t<int>>(burn), cpp11::as_cpp<cpp11::decay_t<int>>(thin), cpp11::as_cpp<cpp11::decay_t<int>>(n_threads), cpp11::as_cpp<cpp11::decay_t<int>>(opt_llk), cpp11::as_cpp<cpp11::decay_t<bool>>(verbose), cpp11::as_cpp<cpp11::decay_t<sexp>>(seed)));
return cpp11::as_sexp(sdpr_rcpp(cpp11::as_cpp<cpp11::decay_t<const doubles&>>(bhat_r), cpp11::as_cpp<cpp11::decay_t<const list&>>(LD), cpp11::as_cpp<cpp11::decay_t<int>>(n), cpp11::as_cpp<cpp11::decay_t<sexp>>(per_variant_sample_size), cpp11::as_cpp<cpp11::decay_t<sexp>>(array), cpp11::as_cpp<cpp11::decay_t<double>>(a), cpp11::as_cpp<cpp11::decay_t<double>>(c), cpp11::as_cpp<cpp11::decay_t<int>>(M), cpp11::as_cpp<cpp11::decay_t<double>>(a0k), cpp11::as_cpp<cpp11::decay_t<double>>(b0k), cpp11::as_cpp<cpp11::decay_t<int>>(iter), cpp11::as_cpp<cpp11::decay_t<int>>(burn), cpp11::as_cpp<cpp11::decay_t<int>>(thin), cpp11::as_cpp<cpp11::decay_t<int>>(n_threads), cpp11::as_cpp<cpp11::decay_t<int>>(opt_llk), cpp11::as_cpp<cpp11::decay_t<bool>>(verbose), cpp11::as_cpp<cpp11::decay_t<sexp>>(seed), cpp11::as_cpp<cpp11::decay_t<bool>>(random_init)));
END_CPP11
}

Expand All @@ -47,7 +47,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_pecotmr_lassosum_rss_rcpp", (DL_FUNC) &_pecotmr_lassosum_rss_rcpp, 5},
{"_pecotmr_prs_cs_rcpp", (DL_FUNC) &_pecotmr_prs_cs_rcpp, 12},
{"_pecotmr_qtl_enrichment_rcpp", (DL_FUNC) &_pecotmr_qtl_enrichment_rcpp, 9},
{"_pecotmr_sdpr_rcpp", (DL_FUNC) &_pecotmr_sdpr_rcpp, 17},
{"_pecotmr_sdpr_rcpp", (DL_FUNC) &_pecotmr_sdpr_rcpp, 18},
{NULL, NULL, 0}
};
}
Expand Down
5 changes: 3 additions & 2 deletions src/sdpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ cpp11::writable::list sdpr_rcpp(
int n_threads = 1,
int opt_llk = 1,
bool verbose = true,
sexp seed = R_NilValue
sexp seed = R_NilValue,
bool random_init = true
) {
// Convert inputs to C++ types
std::vector<double> bhat = cpp11::as_cpp<std::vector<double>>(bhat_r);
Expand Down Expand Up @@ -64,7 +65,7 @@ cpp11::writable::list sdpr_rcpp(

// Call the mcmc function
std::unordered_map<std::string, vec> results = mcmc(
data, n, a, c, M, a0k, b0k, iter, burn, thin, n_threads, opt_llk, verbose, seed_val
data, n, a, c, M, a0k, b0k, iter, burn, thin, n_threads, opt_llk, verbose, seed_val, random_init
);

// Convert results to list
Expand Down
11 changes: 6 additions & 5 deletions src/sdpr_mcmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,13 @@ void MCMC_state::sample_assignment(size_t j, const mcmc_data &dat,
// Log-sum-exp for numerical stability (replaces SSE _mm_max_ps + exp_ps + _mm_hadd_ps)
float max_elem = prob.max();
float log_exp_sum = max_elem
+ std::logf(arma::accu(arma::exp(prob - max_elem)));
+ ::logf(arma::accu(arma::exp(prob - max_elem)));

// Categorical sampling via inverse CDF
// Original: mcmc.cpp lines 155-163
cls_assgn[i + start_i] = M - 1;
for (size_t k = 0; k < M - 1; k++) {
rnd_i -= std::expf(prob(k) - log_exp_sum);
rnd_i -= ::expf(prob(k) - log_exp_sum);
if (rnd_i < 0) {
cls_assgn[i + start_i] = k;
break;
Expand Down Expand Up @@ -217,7 +217,7 @@ void MCMC_state::update_p() {
p[M - 1] = (1 - sum > 0) ? (1 - sum) : 0;

for (size_t i = 0; i < M; i++) {
log_p[i] = std::logf(static_cast<float>(p[i]) + 1e-40f);
log_p[i] = ::logf(static_cast<float>(p[i]) + 1e-40f);
}
}

Expand Down Expand Up @@ -510,14 +510,15 @@ std::unordered_map<std::string, arma::vec> mcmc(
unsigned n_threads,
int opt_llk,
bool verbose,
unsigned int seed
unsigned int seed,
bool random_init
) {

int n_pst = (iter - burn) / thin;

ldmat_data ldmat_dat;

MCMC_state state(data.beta_mrg.size(), M, a0k, b0k, sz, seed);
MCMC_state state(data.beta_mrg.size(), M, a0k, b0k, sz, seed, random_init);

// Deflation correction
for (size_t i = 0; i < data.beta_mrg.size(); i++) {
Expand Down
22 changes: 13 additions & 9 deletions src/sdpr_mcmc.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ std::vector<double> cluster_var;
std::vector<unsigned> suff_stats;
std::vector<double> sumsq;
MCMC_state(size_t num_snp, size_t max_cluster, \
double a0, double b0, double sz, unsigned int seed) {
double a0, double b0, double sz, unsigned int seed,
bool random_init = false) {
a0k = a0; b0k = b0; N = sz;
// Changed May 20 2021
// Now N (sz) is absorbed into A, B; so set to 1.
Expand All @@ -148,14 +149,16 @@ MCMC_state(size_t num_snp, size_t max_cluster, \
suff_stats.assign(max_cluster, 0);
sumsq.assign(max_cluster, 0.0);
V.assign(max_cluster, 0.0);
// Initialize all SNPs to the null cluster (k=0). The original SDPR
// used random initialization (uniform over 0..M-1), but this causes
// the first sample_beta() call to allocate an enormous dense matrix
// (nearly all SNPs are "causal"), crashing with "Mat::init() too large".
// Starting from null is standard MCMC practice and lets the sampler
// discover causal assignments organically.
cls_assgn.assign(num_snp, 0);
r.seed(seed);
if (random_init) {
std::uniform_int_distribution<size_t> init_dist(0, M - 1);
cls_assgn.resize(num_snp);
for (size_t i = 0; i < num_snp; i++) {
cls_assgn[i] = static_cast<int>(init_dist(r));
}
} else {
cls_assgn.assign(num_snp, 0);
}
}

void sample_sigma2();
Expand Down Expand Up @@ -255,5 +258,6 @@ std::unordered_map<std::string, arma::vec> mcmc(
unsigned n_threads,
int opt_llk,
bool verbose,
unsigned int seed
unsigned int seed,
bool random_init = false
);
16 changes: 16 additions & 0 deletions tests/testthat/test_rr_dispatch.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ test_that("sdpr_weights dispatches to sdpr with correct arguments", {
expect_equal(result, seq_len(p) * 0.02)
})

test_that("sdpr_weights forwards init mode", {
  # Small summary-statistics input for the wrapper under test.
  n_variants <- 5
  effect_sizes <- rnorm(n_variants, sd = 0.1)
  ld_matrix <- diag(n_variants)
  sumstats <- list(b = effect_sizes, n = rep(321, n_variants))

  # The stub stands in for sdpr(): it records every extra argument it receives
  # in `recorder` and returns an all-zero `beta_est` of matching length.
  recorder <- new.env(parent = emptyenv())
  stub_sdpr <- function(bhat, LD, n, ...) {
    recorder$dots <- list(...)
    list(beta_est = rep(0, length(bhat)))
  }
  local_mocked_bindings(sdpr = stub_sdpr)

  sdpr_weights(
    stat = sumstats, LD = ld_matrix,
    init = "null", iter = 10, burn = 2
  )

  # The `init` value given to sdpr_weights() must reach sdpr() unchanged.
  expect_equal(recorder$dots$init, "null")
})

test_that("lassosum_rss_weights dispatches to lassosum_rss once per s value", {
set.seed(42)
p <- 10
Expand Down
40 changes: 40 additions & 0 deletions tests/testthat/test_rr_sdpr.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ test_that("sdpr errors on invalid array values", {
)
})

test_that("sdpr errors on invalid init mode", {
  # Calling sdpr() with an unrecognized `init` value is expected to fail with
  # an error whose message contains "should be one of".
  call_with_bad_init <- function() {
    sdpr(bhat = rnorm(5), LD = list(blk1 = diag(5)), n = 100, init = "bogus")
  }

  expect_error(call_with_bad_init(), "should be one of")
})

test_that("sdpr runs successfully", {
set.seed(42)
p <- 10
Expand Down Expand Up @@ -75,6 +82,39 @@ test_that("sdpr with valid array parameter", {
expect_true(all(is.finite(result$beta_est)))
})

test_that("sdpr fixed-seed runs are reproducible with n_threads = 1 and random init", {
  set.seed(42)
  n_variants <- 10
  effect_sizes <- rnorm(n_variants, sd = 0.1)
  ld_matrix <- diag(n_variants)

  # One sampler run with a fixed seed, random initialization, and one thread.
  # Invoked twice below with exactly the same inputs.
  fit_once <- function() {
    sdpr(
      bhat = effect_sizes, LD = list(blk1 = ld_matrix), n = 100,
      iter = 50, burn = 10, thin = 2, verbose = FALSE,
      seed = 42L, init = "random", n_threads = 1
    )
  }
  first_fit <- fit_once()
  second_fit <- fit_once()

  # Both the effect estimates and the heritability must agree across runs.
  expect_equal(first_fit$beta_est, second_fit$beta_est)
  expect_equal(first_fit$h2, second_fit$h2)
})

test_that("sdpr supports null initialization explicitly", {
  set.seed(42)
  n_variants <- 10
  effect_sizes <- rnorm(n_variants, sd = 0.1)
  ld_blocks <- list(blk1 = diag(n_variants))

  # Short run that explicitly requests the "null" initialization mode.
  fit <- sdpr(
    bhat = effect_sizes, LD = ld_blocks, n = 100,
    iter = 50, burn = 10, thin = 2, verbose = FALSE,
    seed = 42L, init = "null"
  )

  # One finite estimate is expected per input variant.
  beta_est <- fit$beta_est
  expect_equal(length(beta_est), n_variants)
  expect_true(all(is.finite(beta_est)))
})

# ---- sdpr signal recovery ----
test_that("sdpr recovers signal direction on simulated genotype data", {
set.seed(2024)
Expand Down