From 8737f00d5bc1dec916c3ff0e59d3d99459d01590 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Thu, 23 Apr 2026 15:40:20 -0400 Subject: [PATCH 1/2] Add support for importance weights in run_csde and InterceptRegression. --- src/csde/api.py | 4 ++++ src/csde/model.py | 59 +++++++++++++++++++++++++++++++++------------- tests/test_csde.py | 42 +++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 17 deletions(-) diff --git a/src/csde/api.py b/src/csde/api.py index edfcae5..f6511b3 100644 --- a/src/csde/api.py +++ b/src/csde/api.py @@ -45,6 +45,7 @@ def run_csde( cell_pop_b: str, gt_key: str, layer_name: Optional[str] = None, + importance_weights: Optional[np.ndarray] = None, **model_kwargs, ) -> pd.DataFrame: """ @@ -61,6 +62,8 @@ def run_csde( cell_pop_b: Name of the second cell population (target group). gt_key: Boolean column in adata_gt.obs indicating if the prediction is correct. layer_name: Layer in adata.layers to use for expression counts. If None, uses .X. + importance_weights: Optional 1-D array of importance weights for the ground-truth + observations. Will be normalized to sum to n_obs internally. **model_kwargs: Additional arguments passed to InterceptRegression (e.g., family, optimizer). Returns: @@ -110,6 +113,7 @@ def get_X(adata): inputs_gt=inputs_gt, inputs_hat=inputs_hat, inputs_unl=inputs_unl, + importance_weights=importance_weights, **model_kwargs, ) model.fit(lambd_=None) diff --git a/src/csde/model.py b/src/csde/model.py index 1979849..24c2b9c 100644 --- a/src/csde/model.py +++ b/src/csde/model.py @@ -176,13 +176,16 @@ def setup(self): "mu", nn.initializers.normal(), (self.n_classes - 1, self.n_features) ) - def __call__(self, x, y): + def __call__(self, x, y, w=None): y_ = y.astype(jnp.int32) mu_placeholder = jnp.zeros_like(self.mu0) mu = jnp.concatenate([mu_placeholder[None], self.mu], axis=0) y_oh = jnp.eye(self.n_classes)[y_] mus_ = y_oh @ mu + self.mu0 + if w is None: + w = jnp.ones_like(y, dtype=jnp.float64) + if self.family == "poisson": rates = jnp.exp(mus_) log_px_c_unsummed = Poisson(rate=rates).log_prob(x) @@ -195,8 +198,8 @@ def __call__(self, x, y): loss = -log_px_c return { - "loss": loss, - "loss_unsummed": -log_px_c_unsummed, + "loss": loss * w, + "loss_unsummed": -log_px_c_unsummed * w[..., None], } @@ -212,6 +215,7 @@ def __init__( optimizer_kwargs: Optional[Dict[str, Any]] = None, family: str = "poisson", jit: bool = True, + importance_weights: Optional[np.ndarray] = None, **kwargs, ): """ @@ -223,6 +227,8 @@ def __init__( optimizer_kwargs: Keyword arguments for the optimizer. family: Distribution family ('poisson' or 'gaussian'). jit: Whether to JIT compile the optimization. + importance_weights: Optional 1-D array of importance weights for the ground-truth + observations. Will be normalized to sum to n_obs. **kwargs: Arguments passed to PPIAbstractClass (inputs_gt, inputs_hat, inputs_unl). 
""" super().__init__(**kwargs) @@ -239,6 +245,19 @@ def __init__( self.inputs_gt = (x_gt, y_gt) self.inputs_hat = (x_hat, y_hat) self.inputs_unl = (x_unl, y_unl) + + if importance_weights is not None: + if importance_weights.shape != (x_gt.shape[0],): + raise ValueError( + "importance_weights must be a 1-D array with the same length " + "as the number of ground-truth observations" + ) + # Normalize so weights sum to n_obs + w = float(x_gt.shape[0]) * importance_weights / importance_weights.sum() + self.importance_weights = w + else: + self.importance_weights = None + n_obs_real = x_gt.shape[0] self.n_features = x_gt.shape[1] @@ -294,13 +313,13 @@ def get_lambda( self.theta = self.get_pointestimate(lambd_=lambd_0) print("done") - hess = self.hessian_fn(self.inputs_gt) + hess = self.hessian_fn(self.inputs_gt, importance_weights=self.importance_weights) inv_hess = np.linalg.pinv(hess) grad_f_unl = self.grad_fn(self.inputs_unl) - grad_f_hat = self.grad_fn(self.inputs_hat) + grad_f_hat = self.grad_fn(self.inputs_hat, w=self.importance_weights) grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) - grad_f_gt = self.grad_fn(self.inputs_gt) + grad_f_gt = self.grad_fn(self.inputs_gt, w=self.importance_weights) grad_f_hat_ = grad_f_hat - grad_f_hat.mean(0) grad_f_gt_ = grad_f_gt - grad_f_gt.mean(0) @@ -402,21 +421,22 @@ def compute_sigma(self, lambd: Union[float, np.ndarray]) -> np.ndarray: return self._compute_sigma(hess, v, self.n) def grad_fn( - self, inputs: Tuple[np.ndarray, np.ndarray], batch_size: int = 128 + self, inputs: Tuple[np.ndarray, np.ndarray], w: Optional[np.ndarray] = None, batch_size: int = 128 ) -> np.ndarray: x, y = inputs n_obs = x.shape[0] - def likelihood(model_params, x, y): - return self.model.apply(model_params, x, y)["loss"] + def likelihood(model_params, x, y, w=None): + return self.model.apply(model_params, x, y, w=w)["loss"] + score = self.jit(jax.jacfwd(likelihood)) all_grads = np.zeros((n_obs, self.n_params)) for i in tqdm(range(0, n_obs, batch_size), desc="Gradient computation"): x_batch = x[i:i+batch_size] y_batch = y[i:i+batch_size] + w_batch = w[i:i+batch_size] if w is not None else None n_obs_batch = x_batch.shape[0] - score = self.jit(jax.jacfwd(likelihood)) - grads = score(self.model_params, x_batch, y_batch) + grads = score(self.model_params, x_batch, y_batch, w=w_batch) grad_mu = np.array(grads["params"]["mu"].reshape(n_obs_batch, -1)) grad_mu0 = np.array(grads["params"]["mu0"].reshape(n_obs_batch, -1)) all_grads[i:i+batch_size] = np.hstack([grad_mu, grad_mu0]) @@ -541,7 +561,7 @@ def zero_init(self): self.model_params = params def hessian_fn( - self, inputs: Tuple[np.ndarray, np.ndarray], device=None + self, inputs: Tuple[np.ndarray, np.ndarray], importance_weights: Optional[np.ndarray] = None, device=None ) -> np.ndarray: x, y = inputs @@ -552,13 +572,13 @@ def hessian_fn( obs_ids = np.arange(n_obs) model_ = self.model - def likelihood(model_params, x, y): - return model_.apply(model_params, x, y)["loss"] + def likelihood(model_params, x, y, w=None): + return model_.apply(model_params, x, y, w=w)["loss"] hess_fn = jax.hessian(likelihood) - def process_hess(x, y): - hess_ = hess_fn(model_params_, x, y) + def process_hess(x, y, w=None): + hess_ = hess_fn(model_params_, x, y, w=w) mu_mu = ( hess_["params"]["mu"]["params"]["mu"] .mean(0) @@ -591,6 +611,11 @@ def process_hess(x, y): y_ = jnp.array(y[[obs_id]], dtype=jnp.int32) x_obs = jax.device_put(x_, device) y_obs = jax.device_put(y_, device) - hess_ = process_hess(x_obs, y_obs) + if importance_weights is not 
None: + w_ = jnp.array(importance_weights[[obs_id]], dtype=jnp.float64) + w_obs = jax.device_put(w_, device) + else: + w_obs = None + hess_ = process_hess(x_obs, y_obs, w_obs) hessian += hess_ / float(n_obs) return hessian diff --git a/tests/test_csde.py b/tests/test_csde.py index 54c06d7..a3538b9 100644 --- a/tests/test_csde.py +++ b/tests/test_csde.py @@ -53,6 +53,48 @@ def test_run_csde(self): ) self.assertTrue(not res.isnull().values.any()) + def test_run_csde_with_importance_weights(self): + n_gt = len(self.adata_gt) + rng = np.random.default_rng(0) + importance_weights = rng.uniform(0.5, 2.0, size=n_gt) + + res = run_csde( + adata_pred=self.adata_pred, + adata_gt=self.adata_gt, + pred_cell_pop_key="cell_type", + cell_pop_a="TypeA", + cell_pop_b="TypeB", + gt_key="is_correct", + optimizer="gd", + optimizer_kwargs={"n_iter": 10}, + importance_weights=importance_weights, + ) + + self.assertIsInstance(res, pd.DataFrame) + self.assertEqual(len(res), 10) + self.assertListEqual( + list(res.columns), ["log_fold_change", "p_value", "p_value_adj"] + ) + self.assertTrue(not res.isnull().values.any()) + + def test_importance_weights_wrong_shape(self): + from csde.model import InterceptRegression + + x_gt, y_gt = self.adata_gt.X.astype(float), np.zeros(len(self.adata_gt), dtype=int) + x_hat = x_gt.copy() + x_unl = self.adata_pred.X.astype(float) + y_hat = np.zeros(len(self.adata_gt), dtype=int) + y_unl = np.zeros(len(self.adata_pred), dtype=int) + + bad_weights = np.ones(len(self.adata_gt) + 5) + with self.assertRaises(ValueError): + InterceptRegression( + inputs_gt=(x_gt, y_gt), + inputs_hat=(x_hat, y_hat), + inputs_unl=(x_unl, y_unl), + importance_weights=bad_weights, + ) + if __name__ == "__main__": unittest.main() From eebeeb7ee392ab6165e8eb1eb148eb983ead70a1 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Fri, 24 Apr 2026 16:49:46 -0400 Subject: [PATCH 2/2] optimized training + importance weights --- src/csde/_base.py | 111 ++++++++++ src/csde/api.py | 6 +- src/csde/{model.py => model_poisson.py} | 270 ++++-------------------- src/csde/optimization.py | 37 +++- tests/test_csde.py | 10 +- 5 files changed, 183 insertions(+), 251 deletions(-) create mode 100644 src/csde/_base.py rename src/csde/{model.py => model_poisson.py} (60%) diff --git a/src/csde/_base.py b/src/csde/_base.py new file mode 100644 index 0000000..bc5e322 --- /dev/null +++ b/src/csde/_base.py @@ -0,0 +1,111 @@ +from typing import List, Optional, Tuple, Union + +import numpy as np + + +class PPIAbstractClass: + def __init__( + self, + inputs_gt: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], + inputs_hat: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], + inputs_unl: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], + lambd_mode: str = "overall", + ): + self.inputs_gt = inputs_gt + self.inputs_hat = inputs_hat + self.inputs_unl = inputs_unl + + inputs_are_tuples = isinstance(inputs_gt, tuple) + if inputs_are_tuples: + self.n = inputs_gt[0].shape[0] + self.N = inputs_unl[0].shape[0] + else: + self.n = self.inputs_gt.shape[0] + self.N = self.inputs_unl.shape[0] + self.r = float(self.n) / self.N + self.theta = None + self.sigma = None + self.hessian = None + self.v = None + self.lambd_mode = lambd_mode + self.lambd_ = None + + def get_asymptotic_distribution(self) -> Tuple[np.ndarray, np.ndarray]: + self.sigma = self.compute_sigma(self.lambd_) + return self.theta, self.sigma + + def compute_sigma(self, lambd: Union[float, np.ndarray]) -> np.ndarray: + grad_f_unl = self.grad_fn(self.inputs_unl) + grad_f_hat = 
self.grad_fn(self.inputs_hat) + grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) + grad_f_gt = self.grad_fn(self.inputs_gt) + + grad_f_ = grad_f_all - grad_f_all.mean(axis=0) + vf = (lambd**2) * (grad_f_.T @ grad_f_) / (self.n + self.N) + rect_ = grad_f_gt - lambd * grad_f_hat + rect_ = rect_ - rect_.mean(axis=0) + vdelta = (rect_.T @ rect_) / self.n + v = vdelta + (self.r * vf) + + hess = self.hessian_fn(self.inputs_gt) + self.hessian = hess + self.v = v + return self._compute_sigma(hess, v, self.n) + + @staticmethod + def _compute_sigma(hess: np.ndarray, v: np.ndarray, n: int) -> np.ndarray: + inv_hess = np.linalg.pinv(hess) + sigma_ = inv_hess @ v @ inv_hess + sigma_ = sigma_ / n + return sigma_ + + def get_lambda( + self, + lambd_0: float = 0.5, + idx_to_optimize: Optional[Union[int, List[int]]] = None, + ) -> Union[float, np.ndarray]: + print("get point estimate ...") + self.theta = self.get_pointestimate(lambd_=lambd_0) + print("done") + + hess = self.hessian_fn(self.inputs_gt) + inv_hess = np.linalg.pinv(hess) + grad_f_unl = self.grad_fn(self.inputs_unl) + grad_f_hat = self.grad_fn(self.inputs_hat) + grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) + grad_f_gt = self.grad_fn(self.inputs_gt) + + grad_f_hat_ = grad_f_hat - grad_f_hat.mean(0) + grad_f_gt_ = grad_f_gt - grad_f_gt.mean(0) + cov1 = (grad_f_hat_.T @ grad_f_gt_) / self.n + cov2 = (grad_f_gt_.T @ grad_f_hat_) / self.n + + grad_f_ = grad_f_all - grad_f_all.mean(axis=0) + vf = (grad_f_.T @ grad_f_) / (self.n + self.N) + num = inv_hess @ (cov1 + cov2) @ inv_hess + denom = 2 * (1.0 + self.r) * (inv_hess @ vf @ inv_hess) + if self.lambd_mode == "element": + lambd_star = num / denom + return np.diag(lambd_star) + elif idx_to_optimize is not None: + print("optimize lambda for a single theta comp.") + if isinstance(idx_to_optimize, int): + return ( + num[idx_to_optimize, idx_to_optimize] + / denom[idx_to_optimize, idx_to_optimize] + ) + else: + return np.trace(num[idx_to_optimize, :][:, idx_to_optimize]) / np.trace( + denom[idx_to_optimize, :][:, idx_to_optimize] + ) + else: + return np.trace(num) / np.trace(denom) + + def get_pointestimate(self, lambd_: float) -> np.ndarray: + raise NotImplementedError + + def grad_fn(self, inputs: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: + raise NotImplementedError + + def hessian_fn(self, inputs: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: + raise NotImplementedError diff --git a/src/csde/api.py b/src/csde/api.py index f6511b3..b150c6c 100644 --- a/src/csde/api.py +++ b/src/csde/api.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd -from csde.model import InterceptRegression +from csde.model_poisson import PoissonIntercept def _map_cell_types( @@ -64,7 +64,7 @@ def run_csde( layer_name: Layer in adata.layers to use for expression counts. If None, uses .X. importance_weights: Optional 1-D array of importance weights for the ground-truth observations. Will be normalized to sum to n_obs internally. - **model_kwargs: Additional arguments passed to InterceptRegression (e.g., family, optimizer). + **model_kwargs: Additional arguments passed to PoissonIntercept (e.g., optimizer). 
Returns: DataFrame indexed by gene names with columns: @@ -109,7 +109,7 @@ def get_X(adata): inputs_unl = (X_unl, y_pred_unl) # inference - model = InterceptRegression( + model = PoissonIntercept( inputs_gt=inputs_gt, inputs_hat=inputs_hat, inputs_unl=inputs_unl, diff --git a/src/csde/model.py b/src/csde/model_poisson.py similarity index 60% rename from src/csde/model.py rename to src/csde/model_poisson.py index 24c2b9c..8d88434 100644 --- a/src/csde/model.py +++ b/src/csde/model_poisson.py @@ -5,170 +5,21 @@ import jax.numpy as jnp import numpy as np import pandas as pd -from numpyro.distributions import Normal, Poisson +from numpyro.distributions import Poisson from statsmodels.stats.multitest import multipletests from tqdm import tqdm +from csde._base import PPIAbstractClass from csde.optimization import _zstat_generic2, optimize_ppi, optimize_ppi_gd -jax.config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", False) -class PPIAbstractClass: - """ - Abstract base class for Prediction-Powered Inference (PPI) models. - """ - - def __init__( - self, - inputs_gt: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], - inputs_hat: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], - inputs_unl: Union[Tuple[np.ndarray, np.ndarray], np.ndarray], - lambd_mode: str = "overall", - ): - """ - Initialize the PPI model. - - Args: - inputs_gt: Ground truth data (features, labels). - inputs_hat: Predicted data for the labeled set (features, predicted labels). - inputs_unl: Unlabeled data (features, predicted labels). - lambd_mode: Mode for lambda parameter ('overall' or 'element'). - """ - self.inputs_gt = inputs_gt - self.inputs_hat = inputs_hat - self.inputs_unl = inputs_unl - - inputs_are_tuples = isinstance(inputs_gt, tuple) - if inputs_are_tuples: - self.n = inputs_gt[0].shape[0] - self.N = inputs_unl[0].shape[0] - else: - self.n = self.inputs_gt.shape[0] - self.N = self.inputs_unl.shape[0] - self.r = float(self.n) / self.N - self.theta = None - self.sigma = None - self.hessian = None - self.v = None - self.lambd_mode = lambd_mode - self.lambd_ = None - - def get_asymptotic_distribution(self) -> Tuple[np.ndarray, np.ndarray]: - """ - Compute the asymptotic distribution of the estimator. - - Returns: - Tuple containing the point estimate (theta) and the covariance matrix (sigma). - """ - self.sigma = self.compute_sigma(self.lambd_) - return self.theta, self.sigma - - def compute_sigma(self, lambd: Union[float, np.ndarray]) -> np.ndarray: - """ - Compute the covariance matrix of the estimator. 
- """ - grad_f_unl = self.grad_fn(self.inputs_unl) - grad_f_hat = self.grad_fn(self.inputs_hat) - grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) - grad_f_gt = self.grad_fn(self.inputs_gt) - - grad_f_ = grad_f_all - grad_f_all.mean(axis=0) - - # Handle lambda broadcasting if necessary - if self.lambd_mode == "element" and isinstance(lambd, np.ndarray): - # This part was implementation specific in subclass, but generalized here based on pattern - # Assuming lambd matches gradient dimensions or is handled in subclass override - pass - - # Base implementation for scalar lambda, override in subclass if needed - vf = (lambd**2) * (grad_f_.T @ grad_f_) / (self.n + self.N) - rect_ = grad_f_gt - lambd * grad_f_hat - rect_ = rect_ - rect_.mean(axis=0) - vdelta = (rect_.T @ rect_) / self.n - v = vdelta + (self.r * vf) - - hess = self.hessian_fn(self.inputs_gt) - self.hessian = hess - self.v = v - return self._compute_sigma(hess, v, self.n) - - @staticmethod - def _compute_sigma(hess: np.ndarray, v: np.ndarray, n: int) -> np.ndarray: - """ - Compute the asymptotic covariance matrix of the parameter estimates. - """ - inv_hess = np.linalg.pinv(hess) - sigma_ = inv_hess @ v @ inv_hess - sigma_ = sigma_ / n - return sigma_ - - def get_lambda( - self, - lambd_0: float = 0.5, - idx_to_optimize: Optional[Union[int, List[int]]] = None, - ) -> Union[float, np.ndarray]: - """ - Estimate the optimal lambda parameter. - """ - print("get point estimate ...") - self.theta = self.get_pointestimate(lambd_=lambd_0) - print("done") - - hess = self.hessian_fn(self.inputs_gt) - - inv_hess = np.linalg.pinv(hess) - grad_f_unl = self.grad_fn(self.inputs_unl) - grad_f_hat = self.grad_fn(self.inputs_hat) - grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) - grad_f_gt = self.grad_fn(self.inputs_gt) - - grad_f_hat_ = grad_f_hat - grad_f_hat.mean(0) - grad_f_gt_ = grad_f_gt - grad_f_gt.mean(0) - cov1 = (grad_f_hat_.T @ grad_f_gt_) / self.n - cov2 = (grad_f_gt_.T @ grad_f_hat_) / self.n - - grad_f_ = grad_f_all - grad_f_all.mean(axis=0) - vf = (grad_f_.T @ grad_f_) / (self.n + self.N) - num = inv_hess @ (cov1 + cov2) @ inv_hess - denom = 2 * (1.0 + self.r) * (inv_hess @ vf @ inv_hess) - if self.lambd_mode == "element": - lambd_star = num / denom - return np.diag(lambd_star) - elif idx_to_optimize is not None: - print("optimize lambda for a single theta comp.") - if isinstance(idx_to_optimize, int): - return ( - num[idx_to_optimize, idx_to_optimize] - / denom[idx_to_optimize, idx_to_optimize] - ) - else: - return np.trace(num[idx_to_optimize, :][:, idx_to_optimize]) / np.trace( - denom[idx_to_optimize, :][:, idx_to_optimize] - ) - else: - return np.trace(num) / np.trace(denom) - - def get_pointestimate(self, lambd_: float) -> np.ndarray: - raise NotImplementedError - - def grad_fn(self, inputs: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: - raise NotImplementedError - - def hessian_fn(self, inputs: Tuple[np.ndarray, np.ndarray]) -> np.ndarray: - raise NotImplementedError - - -class RegressionInterceptModel(nn.Module): - """ - Flax module for the intercept regression model. 
- """ - +class PoissonInterceptModule(nn.Module): n_classes: int n_features: int mu_prior_std: Union[float, jnp.ndarray] n_obs_real: int - family: str def setup(self): self.mu0 = self.param("mu0", nn.initializers.normal(), (self.n_features)) @@ -186,15 +37,9 @@ def __call__(self, x, y, w=None): if w is None: w = jnp.ones_like(y, dtype=jnp.float64) - if self.family == "poisson": - rates = jnp.exp(mus_) - log_px_c_unsummed = Poisson(rate=rates).log_prob(x) - log_px_c = log_px_c_unsummed.sum(axis=-1) - elif self.family == "gaussian": - log_px_c_unsummed = Normal(loc=mus_, scale=1.0).log_prob(x) - log_px_c = log_px_c_unsummed.sum(axis=-1) - else: - raise ValueError(f"Unknown family: {self.family}") + rates = jnp.exp(mus_) + log_px_c_unsummed = Poisson(rate=rates).log_prob(x) + log_px_c = log_px_c_unsummed.sum(axis=-1) loss = -log_px_c return { @@ -203,34 +48,16 @@ def __call__(self, x, y, w=None): } -class InterceptRegression(PPIAbstractClass): - """ - Intercept Regression model for spatial differential expression analysis. - """ - +class PoissonIntercept(PPIAbstractClass): def __init__( self, mu_prior_std: Optional[Union[float, jnp.ndarray]] = None, optimizer: str = "gd", optimizer_kwargs: Optional[Dict[str, Any]] = None, - family: str = "poisson", jit: bool = True, importance_weights: Optional[np.ndarray] = None, **kwargs, ): - """ - Initialize the InterceptRegression model. - - Args: - mu_prior_std: Prior standard deviation for mu. - optimizer: Optimization method ('gd' or 'lbfgs'). - optimizer_kwargs: Keyword arguments for the optimizer. - family: Distribution family ('poisson' or 'gaussian'). - jit: Whether to JIT compile the optimization. - importance_weights: Optional 1-D array of importance weights for the ground-truth - observations. Will be normalized to sum to n_obs. - **kwargs: Arguments passed to PPIAbstractClass (inputs_gt, inputs_hat, inputs_unl). - """ super().__init__(**kwargs) x_gt, y_gt = self.inputs_gt @@ -252,7 +79,6 @@ def __init__( "importance_weights must be a 1-D array with the same length " "as the number of ground-truth observations" ) - # Normalize so weights sum to n_obs w = float(x_gt.shape[0]) * importance_weights / importance_weights.sum() self.importance_weights = w else: @@ -262,12 +88,11 @@ def __init__( self.n_features = x_gt.shape[1] self.n_params = (self.n_classes - 1) * self.n_features + self.n_features - self.model = RegressionInterceptModel( + self.model = PoissonInterceptModule( n_classes=self.n_classes, n_features=self.n_features, mu_prior_std=mu_prior_std, n_obs_real=n_obs_real, - family=family, ) self.model_params = None @@ -281,13 +106,6 @@ def __init__( def fit( self, lambd_: Optional[Union[float, np.ndarray]] = None, refit: bool = False ): - """ - Fit the model parameters. - - Args: - lambd_: Lambda parameter. If None, it is estimated. - refit: Whether to re-initialize parameters before fitting. - """ if lambd_ is None: lambd_ = self.get_lambda() print(f"lambda: {lambd_}") @@ -301,20 +119,13 @@ def get_lambda( lambd_0: float = 0.5, idx_to_optimize: Optional[Union[int, List[int]]] = None, ) -> Union[float, np.ndarray]: - """ - Estimate the optimal lambda parameter. - Overriding parent method to handle element-wise lambda specific logic. - """ - # Call parent to get num and denom matrices/values if needed, but the parent implementation - # might need access to _construct_contrast for element-wise which is specific to this class. - # So copying logic from original implementation to be safe and consistent. 
- print("get point estimate ...") self.theta = self.get_pointestimate(lambd_=lambd_0) print("done") - hess = self.hessian_fn(self.inputs_gt, importance_weights=self.importance_weights) - + hess = self.hessian_fn( + self.inputs_gt, importance_weights=self.importance_weights + ) inv_hess = np.linalg.pinv(hess) grad_f_unl = self.grad_fn(self.inputs_unl) grad_f_hat = self.grad_fn(self.inputs_hat, w=self.importance_weights) @@ -361,6 +172,7 @@ def get_pointestimate(self, lambd_: Union[float, np.ndarray]) -> np.ndarray: model_params0 = self.model_params if self.model_params is not None else None if self.optimizer == "lbfgs": + print("optimize with lbfgs") model_params = optimize_ppi( self.model, lambd_=lambd_, @@ -374,6 +186,7 @@ def get_pointestimate(self, lambd_: Union[float, np.ndarray]) -> np.ndarray: **self.optimizer_kwargs, ) elif self.optimizer == "gd": + print("optimize with gd") model_params = optimize_ppi_gd( self.model, lambd_=lambd_, @@ -396,11 +209,10 @@ def get_pointestimate(self, lambd_: Union[float, np.ndarray]) -> np.ndarray: return np.hstack([mu, mu0]) def compute_sigma(self, lambd: Union[float, np.ndarray]) -> np.ndarray: - # Override to handle element-wise lambda and specific broadcasting grad_f_unl = self.grad_fn(self.inputs_unl) - grad_f_hat = self.grad_fn(self.inputs_hat) + grad_f_hat = self.grad_fn(self.inputs_hat, w=self.importance_weights) grad_f_all = np.vstack([grad_f_hat, grad_f_unl]) - grad_f_gt = self.grad_fn(self.inputs_gt) + grad_f_gt = self.grad_fn(self.inputs_gt, w=self.importance_weights) grad_f_ = grad_f_all - grad_f_all.mean(axis=0) if self.lambd_mode == "element": @@ -415,13 +227,18 @@ def compute_sigma(self, lambd: Union[float, np.ndarray]) -> np.ndarray: vdelta = (rect_.T @ rect_) / self.n v = vdelta + (self.r * vf) - hess = self.hessian_fn(self.inputs_gt) + hess = self.hessian_fn( + self.inputs_gt, importance_weights=self.importance_weights + ) self.hessian = hess self.v = v return self._compute_sigma(hess, v, self.n) def grad_fn( - self, inputs: Tuple[np.ndarray, np.ndarray], w: Optional[np.ndarray] = None, batch_size: int = 128 + self, + inputs: Tuple[np.ndarray, np.ndarray], + w: Optional[np.ndarray] = None, + batch_size: int = 128, ) -> np.ndarray: x, y = inputs n_obs = x.shape[0] @@ -432,20 +249,19 @@ def likelihood(model_params, x, y, w=None): score = self.jit(jax.jacfwd(likelihood)) all_grads = np.zeros((n_obs, self.n_params)) for i in tqdm(range(0, n_obs, batch_size), desc="Gradient computation"): - x_batch = x[i:i+batch_size] - y_batch = y[i:i+batch_size] - w_batch = w[i:i+batch_size] if w is not None else None + x_batch = x[i : i + batch_size] + y_batch = y[i : i + batch_size] + w_batch = w[i : i + batch_size] if w is not None else None n_obs_batch = x_batch.shape[0] grads = score(self.model_params, x_batch, y_batch, w=w_batch) grad_mu = np.array(grads["params"]["mu"].reshape(n_obs_batch, -1)) grad_mu0 = np.array(grads["params"]["mu0"].reshape(n_obs_batch, -1)) - all_grads[i:i+batch_size] = np.hstack([grad_mu, grad_mu0]) + all_grads[i : i + batch_size] = np.hstack([grad_mu, grad_mu0]) return np.array(all_grads) def _construct_contrast(self, feature_id: int, idx_a: int) -> np.ndarray: mu_contrast = np.zeros((self.n_classes - 1, self.n_features)) mu_contrast[idx_a - 1, feature_id] = 1.0 - mu0_contrast = np.zeros(self.n_features) contrast = np.hstack([mu_contrast.flatten(), mu0_contrast]) return contrast.astype(int) @@ -453,7 +269,6 @@ def _construct_contrast(self, feature_id: int, idx_a: int) -> np.ndarray: def idx_to_feat(self) -> np.ndarray: 
mu_identifier = np.ones((self.n_classes - 1, self.n_features)) mu_identifier = mu_identifier * np.arange(self.n_features) - mu0_identifier = np.arange(self.n_features) identifier = np.hstack([mu_identifier.flatten(), mu0_identifier]) return identifier.astype(int) @@ -463,13 +278,11 @@ def construct_contrast(self, idx_a: int) -> np.ndarray: self._construct_contrast(feature_id, idx_a) for feature_id in range(self.n_features) ] - _contrast = np.vstack(_contrast) - return _contrast + return np.vstack(_contrast) def get_beta(self, idx_a: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: if idx_a == 0: raise ValueError("`class_a` cannot be the reference class.") - contrast = self.construct_contrast(idx_a) beta = contrast @ self.theta cov = contrast @ self.sigma @ contrast.T @@ -493,8 +306,7 @@ def _get_param_mask(self, feature_id: int) -> np.ndarray: for class_id in range(1, self.n_classes) ] mu0_indices = [self._get_param_id(feature_id=feature_id, param_type="mu0")] - indices_to_keep = np.hstack([mu_indices, mu0_indices]) - return indices_to_keep + return np.hstack([mu_indices, mu0_indices]) def test_differential_expression( self, @@ -502,17 +314,6 @@ def test_differential_expression( feature_names: Optional[List[str]] = None, cond_thresh: float = np.inf, ) -> pd.DataFrame: - """ - Perform differential expression testing. - - Args: - idx_a: The index of the target class (1 or 2). Note: 0 is the reference class. - feature_names: List of feature names. - cond_thresh: Condition number threshold for Hessian. - - Returns: - DataFrame containing the results (p-values, log-fold changes, etc.). - """ idx_a_ = idx_a - 1 results = [] for feature_id in range(self.n_features): @@ -542,7 +343,7 @@ def test_differential_expression( } ) res = pd.DataFrame(results) - res["pval"].iloc[np.isnan(res["pval"])] = 1.0 + res.loc[np.isnan(res["pval"]), "pval"] = 1.0 res["padj"] = multipletests(res["pval"], method="fdr_bh")[1] res["is_significant_005"] = res["padj"] < 0.05 if feature_names is not None: @@ -561,7 +362,10 @@ def zero_init(self): self.model_params = params def hessian_fn( - self, inputs: Tuple[np.ndarray, np.ndarray], importance_weights: Optional[np.ndarray] = None, device=None + self, + inputs: Tuple[np.ndarray, np.ndarray], + importance_weights: Optional[np.ndarray] = None, + device=None, ) -> np.ndarray: x, y = inputs @@ -597,13 +401,12 @@ def process_hess(x, y, w=None): .mean(0) .reshape(self.n_features, self.n_features) ) - blk = jnp.block( + return jnp.block( [ [mu_mu, mu_mu0], [mu_mu0.T, mu0_mu0], ] ) - return blk hessian = np.zeros((self.n_params, self.n_params), dtype=np.float64) for obs_id in tqdm(obs_ids, desc="Hessian computation"): @@ -616,6 +419,5 @@ def process_hess(x, y, w=None): w_obs = jax.device_put(w_, device) else: w_obs = None - hess_ = process_hess(x_obs, y_obs, w_obs) - hessian += hess_ / float(n_obs) + hessian += process_hess(x_obs, y_obs, w_obs) / float(n_obs) return hessian diff --git a/src/csde/optimization.py b/src/csde/optimization.py index 9269d1b..e7ffa02 100644 --- a/src/csde/optimization.py +++ b/src/csde/optimization.py @@ -102,6 +102,7 @@ def optimize_ppi_gd( y_hat: jnp.ndarray, x_unl: jnp.ndarray, y_unl: jnp.ndarray, + w: Optional[jnp.ndarray] = None, model_params0: Optional[Any] = None, lambd_: float = 1.0, tol: float = 1e-3, @@ -123,6 +124,8 @@ def optimize_ppi_gd( y_hat = jax.device_put(jnp.array(y_hat, dtype=jnp.int32)) x_unl = jax.device_put(jnp.array(x_unl, dtype=jnp.float64)) y_unl = jax.device_put(jnp.array(y_unl, dtype=jnp.int32)) + if w is not None: + w = 
jax.device_put(jnp.array(w, dtype=jnp.float64))
 
     x0 = jnp.ones((32, x_gt.shape[1]), dtype=jnp.float64)
     y0 = jnp.ones(32, dtype=jnp.int32)
@@ -142,23 +145,30 @@
 
     lambd_ = jax.device_put(lambd_)
 
-    def loss_fn(zetas):
-        loss_gt = model.apply(zetas, x_gt, y_gt)["loss_unsummed"].mean(0)
-        loss_hat = model.apply(zetas, x_hat, y_hat)["loss_unsummed"].mean(0)
-        loss_unl = model.apply(zetas, x_unl, y_unl)["loss_unsummed"].mean(0)
-        loss = (lambd_ * loss_unl) - (lambd_ * loss_hat) + loss_gt
-        loss = loss.sum(-1)
-        return loss
+    def step_fn(theta_, opt_state_, x_gt_, y_gt_, x_hat_, y_hat_, x_unl_, y_unl_):
+        def loss_fn(zetas):
+            loss_gt = model.apply(zetas, x_gt_, y_gt_, w=w)["loss_unsummed"].mean(0)
+            loss_hat = model.apply(zetas, x_hat_, y_hat_, w=w)["loss_unsummed"].mean(0)
+            loss_unl = model.apply(zetas, x_unl_, y_unl_)["loss_unsummed"].mean(0)
+            loss = (lambd_ * loss_unl) - (lambd_ * loss_hat) + loss_gt
+            loss = loss.sum(-1)
+            return loss
+
+        loss, grad = jax.value_and_grad(loss_fn)(theta_)
+        updates, opt_state_ = opt.update(grad, opt_state_, theta_)
+        theta_ = optax.apply_updates(theta_, updates)
+        return theta_, opt_state_, loss
+
+    compiled_step = jitter(step_fn)
 
-    value_and_grad_fn = jitter(jax.value_and_grad(loss_fn))
     previous_loss = 1e6
     print("lambda:", lambd_)
     print("tol:", tol_)
     pbar = trange(n_iter)
     for _ in pbar:
-        loss, grad = value_and_grad_fn(theta)
-        updates, opt_state = opt.update(grad, opt_state, theta)
-        theta = optax.apply_updates(theta, updates)
+        theta, opt_state, loss = compiled_step(
+            theta, opt_state, x_gt, y_gt, x_hat, y_hat, x_unl, y_unl
+        )
 
         stopping_criterion = np.abs(loss - previous_loss)
         if np.allclose(loss, previous_loss, atol=tol_, rtol=0):
diff --git a/tests/test_csde.py b/tests/test_csde.py
index a3538b9..83e6471 100644
--- a/tests/test_csde.py
+++ b/tests/test_csde.py
@@ -1,7 +1,9 @@
 import unittest
+
+import anndata
 import numpy as np
 import pandas as pd
-import anndata
+
 from csde import run_csde
 
 
@@ -78,9 +80,11 @@ def test_run_csde_with_importance_weights(self):
         self.assertTrue(not res.isnull().values.any())
 
     def test_importance_weights_wrong_shape(self):
-        from csde.model import InterceptRegression
+        from csde.model_poisson import PoissonIntercept as InterceptRegression
 
-        x_gt, y_gt = self.adata_gt.X.astype(float), np.zeros(len(self.adata_gt), dtype=int)
+        x_gt, y_gt = self.adata_gt.X.astype(float), np.zeros(
+            len(self.adata_gt), dtype=int
+        )
         x_hat = x_gt.copy()
         x_unl = self.adata_pred.X.astype(float)
         y_hat = np.zeros(len(self.adata_gt), dtype=int)
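
Reviewer note, not part of the patch: a minimal end-to-end sketch of the new
importance_weights argument. adata_gt and adata_pred stand for AnnData objects
prepared as in tests/test_csde.py (their construction is assumed here); the
call signature, normalization rule, and output columns are taken from the
diffs above. With per-observation losses l and weights w, the point estimate
minimizes mean_gt[w * l(y)] + lambda * (mean_unl[l(y_hat)] - mean_gt[w * l(y_hat)]),
matching loss_fn in optimization.py.

    import numpy as np
    from csde import run_csde

    rng = np.random.default_rng(0)
    # One weight per ground-truth observation. run_csde rescales them
    # internally, w <- n_obs * w / w.sum(), so only relative weights matter;
    # a length mismatch raises ValueError before any optimization runs.
    weights = rng.uniform(0.5, 2.0, size=adata_gt.n_obs)

    res = run_csde(
        adata_pred=adata_pred,
        adata_gt=adata_gt,
        pred_cell_pop_key="cell_type",
        cell_pop_a="TypeA",
        cell_pop_b="TypeB",
        gt_key="is_correct",
        importance_weights=weights,
        optimizer="gd",
        optimizer_kwargs={"n_iter": 10},
    )
    # res: DataFrame indexed by gene names with columns
    # ["log_fold_change", "p_value", "p_value_adj"].

Normalizing the weights to sum to n_obs presumably keeps the weighted
log-likelihood on the same scale as the unweighted one, so the lambda estimate
and the sandwich covariance stay comparable across weightings.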