Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions src/csde/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from typing import List, Optional, Tuple, Union

import numpy as np


class PPIAbstractClass:
def __init__(
self,
inputs_gt: Union[Tuple[np.ndarray, np.ndarray], np.ndarray],
inputs_hat: Union[Tuple[np.ndarray, np.ndarray], np.ndarray],
inputs_unl: Union[Tuple[np.ndarray, np.ndarray], np.ndarray],
lambd_mode: str = "overall",
):
self.inputs_gt = inputs_gt
self.inputs_hat = inputs_hat
self.inputs_unl = inputs_unl

inputs_are_tuples = isinstance(inputs_gt, tuple)
if inputs_are_tuples:
self.n = inputs_gt[0].shape[0]
self.N = inputs_unl[0].shape[0]
else:
self.n = self.inputs_gt.shape[0]
self.N = self.inputs_unl.shape[0]
self.r = float(self.n) / self.N
self.theta = None
self.sigma = None
self.hessian = None
self.v = None
self.lambd_mode = lambd_mode
self.lambd_ = None

def get_asymptotic_distribution(self) -> Tuple[np.ndarray, np.ndarray]:
self.sigma = self.compute_sigma(self.lambd_)
return self.theta, self.sigma

def compute_sigma(self, lambd: Union[float, np.ndarray]) -> np.ndarray:
grad_f_unl = self.grad_fn(self.inputs_unl)
grad_f_hat = self.grad_fn(self.inputs_hat)
grad_f_all = np.vstack([grad_f_hat, grad_f_unl])
grad_f_gt = self.grad_fn(self.inputs_gt)

grad_f_ = grad_f_all - grad_f_all.mean(axis=0)
vf = (lambd**2) * (grad_f_.T @ grad_f_) / (self.n + self.N)
rect_ = grad_f_gt - lambd * grad_f_hat
rect_ = rect_ - rect_.mean(axis=0)
vdelta = (rect_.T @ rect_) / self.n
v = vdelta + (self.r * vf)

hess = self.hessian_fn(self.inputs_gt)
self.hessian = hess
self.v = v
return self._compute_sigma(hess, v, self.n)

@staticmethod
def _compute_sigma(hess: np.ndarray, v: np.ndarray, n: int) -> np.ndarray:
inv_hess = np.linalg.pinv(hess)
sigma_ = inv_hess @ v @ inv_hess
sigma_ = sigma_ / n
return sigma_

def get_lambda(
self,
lambd_0: float = 0.5,
idx_to_optimize: Optional[Union[int, List[int]]] = None,
) -> Union[float, np.ndarray]:
print("get point estimate ...")
self.theta = self.get_pointestimate(lambd_=lambd_0)
print("done")

hess = self.hessian_fn(self.inputs_gt)
inv_hess = np.linalg.pinv(hess)
grad_f_unl = self.grad_fn(self.inputs_unl)
grad_f_hat = self.grad_fn(self.inputs_hat)
grad_f_all = np.vstack([grad_f_hat, grad_f_unl])
grad_f_gt = self.grad_fn(self.inputs_gt)

grad_f_hat_ = grad_f_hat - grad_f_hat.mean(0)
grad_f_gt_ = grad_f_gt - grad_f_gt.mean(0)
cov1 = (grad_f_hat_.T @ grad_f_gt_) / self.n
cov2 = (grad_f_gt_.T @ grad_f_hat_) / self.n

grad_f_ = grad_f_all - grad_f_all.mean(axis=0)
vf = (grad_f_.T @ grad_f_) / (self.n + self.N)
num = inv_hess @ (cov1 + cov2) @ inv_hess
denom = 2 * (1.0 + self.r) * (inv_hess @ vf @ inv_hess)
if self.lambd_mode == "element":
lambd_star = num / denom
return np.diag(lambd_star)
elif idx_to_optimize is not None:
print("optimize lambda for a single theta comp.")
if isinstance(idx_to_optimize, int):
return (
num[idx_to_optimize, idx_to_optimize]
/ denom[idx_to_optimize, idx_to_optimize]
)
else:
return np.trace(num[idx_to_optimize, :][:, idx_to_optimize]) / np.trace(
denom[idx_to_optimize, :][:, idx_to_optimize]
)
else:
return np.trace(num) / np.trace(denom)

def get_pointestimate(self, lambd_: float) -> np.ndarray:
raise NotImplementedError

def grad_fn(self, inputs: Tuple[np.ndarray, np.ndarray]) -> np.ndarray:
raise NotImplementedError

def hessian_fn(self, inputs: Tuple[np.ndarray, np.ndarray]) -> np.ndarray:
raise NotImplementedError
10 changes: 7 additions & 3 deletions src/csde/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd

from csde.model import InterceptRegression
from csde.model_poisson import PoissonIntercept


def _map_cell_types(
Expand Down Expand Up @@ -45,6 +45,7 @@ def run_csde(
cell_pop_b: str,
gt_key: str,
layer_name: Optional[str] = None,
importance_weights: Optional[np.ndarray] = None,
**model_kwargs,
) -> pd.DataFrame:
"""
Expand All @@ -61,7 +62,9 @@ def run_csde(
cell_pop_b: Name of the second cell population (target group).
gt_key: Boolean column in adata_gt.obs indicating if the prediction is correct.
layer_name: Layer in adata.layers to use for expression counts. If None, uses .X.
**model_kwargs: Additional arguments passed to InterceptRegression (e.g., family, optimizer).
importance_weights: Optional 1-D array of importance weights for the ground-truth
observations. Will be normalized to sum to n_obs internally.
**model_kwargs: Additional arguments passed to PoissonIntercept (e.g., optimizer).

Returns:
DataFrame indexed by gene names with columns:
Expand Down Expand Up @@ -106,10 +109,11 @@ def get_X(adata):
inputs_unl = (X_unl, y_pred_unl)

# inference
model = InterceptRegression(
model = PoissonIntercept(
inputs_gt=inputs_gt,
inputs_hat=inputs_hat,
inputs_unl=inputs_unl,
importance_weights=importance_weights,
**model_kwargs,
)
model.fit(lambd_=None)
Expand Down
Loading
Loading