From d5058c541e7e734a0bffadaa9c7896631907ef08 Mon Sep 17 00:00:00 2001 From: Raymond Lim Date: Thu, 7 May 2026 13:50:36 -0400 Subject: [PATCH 1/6] fix: upcast float16/bfloat16 features to float32 in aggregate_sample_features Same raw h5 read pattern as merge_annotation_features (line 142). After np.array(h5["features"]), cast bfloat16 (|V2 opaque void) and float16 to float32 before passing to downstream aggregation. - bfloat16 stored as |V2 via ml_dtypes: view + astype(float32) - float16: direct astype(float32) Tests added: - test_float16_features_upcast_to_float32 - test_bfloat16_features_upcast_to_float32 All 17 tests pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mussel/cli/aggregate_sample_features.py | 13 ++- .../cli/test_aggregate_sample_features.py | 87 +++++++++++++++++-- 2 files changed, 94 insertions(+), 6 deletions(-) diff --git a/mussel/cli/aggregate_sample_features.py b/mussel/cli/aggregate_sample_features.py index 0963aabe..506706e8 100644 --- a/mussel/cli/aggregate_sample_features.py +++ b/mussel/cli/aggregate_sample_features.py @@ -139,7 +139,18 @@ def main(cfg: AggregateSampleFeaturesConfig): coords_list = [] for i in indices: with h5py.File(patch_features_h5_paths[i], "r") as h5: - features_list.append(np.array(h5["features"])) + features_arr = np.array(h5["features"]) + # HDF5 stores bfloat16 as |V2 opaque void (via ml_dtypes). Cast both + # bfloat16 and float16 to float32 so downstream consumers work correctly. + if features_arr.dtype.kind == "V" and features_arr.dtype.itemsize == 2: + import ml_dtypes # noqa: PLC0415 + + features_arr = features_arr.view(ml_dtypes.bfloat16).astype( + np.float32 + ) + elif features_arr.dtype == np.float16: + features_arr = features_arr.astype(np.float32) + features_list.append(features_arr) coords_list.append(h5["coords"][:]) result = aggregate_sample_features( diff --git a/tests/mussel/cli/test_aggregate_sample_features.py b/tests/mussel/cli/test_aggregate_sample_features.py index ad29cf1e..69a94f07 100644 --- a/tests/mussel/cli/test_aggregate_sample_features.py +++ b/tests/mussel/cli/test_aggregate_sample_features.py @@ -13,10 +13,10 @@ def _make_data(n_tiles, dim=4): return features, coords -def _write_fake_h5(path, n_tiles, dim=4, seed=0): +def _write_fake_h5(path, n_tiles, dim=4, seed=0, feature_dtype=np.float32): """Write a synthetic feature H5 with 'features' and 'coords' datasets.""" rng = np.random.default_rng(seed) - features = rng.random((n_tiles, dim)).astype(np.float32) + features = rng.random((n_tiles, dim)).astype(feature_dtype) coords = rng.integers(0, 1000, (n_tiles, 2)) with h5py.File(path, "w") as f: f.create_dataset("features", data=features) @@ -120,8 +120,9 @@ def test_subsample_tiles_invalid_strategy(): # ============================================================================= -from mussel.utils.feature_extract import \ - aggregate_sample_features as _aggregate_sample_features +from mussel.utils.feature_extract import ( + aggregate_sample_features as _aggregate_sample_features, +) def test_aggregate_sample_features_invalid_shapes(): @@ -216,7 +217,10 @@ def test_aggregate_sample_features_save_pt_false(tmp_path): def test_aggregate_sample_features_two_samples(tmp_path): """Three slides, two samples — two entries in result.""" rng = np.random.default_rng(0) - slides = [(rng.random((10, 4)).astype(np.float32), rng.integers(0, 1000, (10, 2))) for _ in range(3)] + slides = [ + (rng.random((10, 4)).astype(np.float32), rng.integers(0, 1000, (10, 2))) + for _ in range(3) + ] features_list = [f for f, _ in slides] coords_list = [c for _, c in slides] @@ -315,3 +319,76 @@ def test_cli_mismatched_lengths_raises(tmp_path): ) with pytest.raises((ValueError, Exception)): mussel.cli.aggregate_sample_features.main(OmegaConf.structured(cfg)) + + +# ============================================================================= +# Precision upcast tests +# ============================================================================= + + +def test_float16_features_upcast_to_float32(tmp_path): + """float16 features stored in h5 are upcast to float32 in the output H5.""" + n_tiles, dim = 6, 8 + rng = np.random.default_rng(0) + orig_f32 = rng.random((n_tiles, dim)).astype(np.float32) + expected = orig_f32.astype(np.float16).astype(np.float32) + + h5_path = tmp_path / "slide.h5" + _write_fake_h5(h5_path, n_tiles=n_tiles, dim=dim, seed=0, feature_dtype=np.float16) + + cfg = AggregateSampleFeaturesConfig( + patch_features_h5_paths=[str(h5_path)], + sample_ids=["s1"], + output_dir=str(tmp_path / "out"), + save_pt=False, + ) + mussel.cli.aggregate_sample_features.main(OmegaConf.structured(cfg)) + + with h5py.File(tmp_path / "out" / "s1.features.h5") as f: + out_features = f["features"][:] + + assert ( + out_features.dtype == np.float32 + ), f"Expected float32 output, got {out_features.dtype}" + np.testing.assert_allclose( + out_features, + expected, + rtol=1e-3, + err_msg="Float16→float32 values differ more than expected", + ) + + +def test_bfloat16_features_upcast_to_float32(tmp_path): + """bfloat16 features (stored as |V2 opaque void in h5) are upcast to float32.""" + import ml_dtypes + + n_tiles, dim = 6, 8 + rng = np.random.default_rng(0) + orig_f32 = rng.random((n_tiles, dim)).astype(np.float32) + expected = orig_f32.astype(ml_dtypes.bfloat16).astype(np.float32) + + h5_path = tmp_path / "slide.h5" + _write_fake_h5( + h5_path, n_tiles=n_tiles, dim=dim, seed=0, feature_dtype=ml_dtypes.bfloat16 + ) + + cfg = AggregateSampleFeaturesConfig( + patch_features_h5_paths=[str(h5_path)], + sample_ids=["s1"], + output_dir=str(tmp_path / "out"), + save_pt=False, + ) + mussel.cli.aggregate_sample_features.main(OmegaConf.structured(cfg)) + + with h5py.File(tmp_path / "out" / "s1.features.h5") as f: + out_features = f["features"][:] + + assert ( + out_features.dtype == np.float32 + ), f"Expected float32 output, got {out_features.dtype}" + np.testing.assert_allclose( + out_features, + expected, + rtol=1e-2, + err_msg="Bfloat16→float32 values differ more than expected", + ) From c660ed4a727094f5a7e040e81778f7e82624b9fe Mon Sep 17 00:00:00 2001 From: Raymond Lim Date: Thu, 7 May 2026 14:09:29 -0400 Subject: [PATCH 2/6] feat: add abmil_benchmark CLI for precision benchmarking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a new abmil_benchmark CLI entry point to the Mussel package that trains a Gated-ABMIL classifier on per-slide H5 feature files and reports AUROC across multiple seeds with bootstrap 95% CIs. Key design: - Follows linear_probe_benchmark conventions (Hydra config, ConfigStore, same JSON output format) - Reads per-slide H5 files (h5["features"] shape: n_tiles × feature_dim) - cfg.dtype (float32/float16/bfloat16) cast at load time to simulate the precision loss from reduced-precision feature storage - AbmilClassifier wraps mussel.models.abmil.ABMIL (n_branches=1) with a linear head; model trains in float32 regardless of input dtype - Supports optional pre-defined split column or random slide-level split - Initialises best_state to model weights before training so load_state_dict never receives None when val set is single-class Files changed: - mussel/cli/abmil_benchmark.py (new) - pyproject.toml: add abmil_benchmark entry point - tests/mussel/cli/test_abmil_benchmark.py (new, 9 tests) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mussel/cli/abmil_benchmark.py | 577 +++++++++++++++++++++++ pyproject.toml | 1 + tests/mussel/cli/test_abmil_benchmark.py | 221 +++++++++ 3 files changed, 799 insertions(+) create mode 100644 mussel/cli/abmil_benchmark.py create mode 100644 tests/mussel/cli/test_abmil_benchmark.py diff --git a/mussel/cli/abmil_benchmark.py b/mussel/cli/abmil_benchmark.py new file mode 100644 index 00000000..c960497a --- /dev/null +++ b/mussel/cli/abmil_benchmark.py @@ -0,0 +1,577 @@ +"""ABMIL benchmark: evaluate WSI classification AUROC from patch features. + +Measures the effect of feature storage precision (float32/float16/bfloat16) on +slide-level classification performance, using Attention-Based MIL (ABMIL) as the +aggregation model. + +Usage:: + + abmil_benchmark \\ + features_dir=/path/to/feature/h5s \\ + labels_parquet=/path/to/labels.parquet \\ + target_col=STK11_ONCOGENIC \\ + dtype=float16 \\ + output_summary_json=results_float16.json +""" + +import json +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import h5py +import hydra +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from hydra.conf import HelpConf, HydraConf +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING +from sklearn.metrics import roc_auc_score +from torch.utils.data import DataLoader, Dataset + +from mussel.models.abmil import ABMIL +from mussel.utils.feature_extract import _numpy_to_torch + +logger = logging.getLogger(__name__) + +_VALID_DTYPES = ("float32", "float16", "bfloat16") +_TORCH_DTYPE = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, +} + + +# ── Config ──────────────────────────────────────────────────────────────────── + + +@dataclass +class AbmilBenchmarkConfig: + """ + features_dir (str): Directory of per-slide HDF5 feature files ({slide_id}.h5). + Each file must contain a dataset named 'features' with shape (n_tiles, feature_dim). + labels_parquet (str): Parquet file with columns: slide_id, , and optionally + a split column. Rows without a matching H5 file are silently dropped. + target_col (str): Column name of the binary target label (integer 0/1). + split_col (Optional[str]): Column containing 'train'/'val'/'test' split labels. + If None, a random split is computed using test_size and val_size. + output_summary_json (str): Output path for the JSON metrics summary. + n_seeds (int): Number of random seeds. Mean +/- std is reported across seeds. + random_state (int): Primary seed; seeds random_state ... random_state+n_seeds-1 are run. + test_size (float): Fraction of slides held out as test (only used when split_col is None). + val_size (float): Fraction of slides held out as val (only used when split_col is None). + n_epochs (int): Number of training epochs per seed. + lr (float): Adam learning rate. + batch_size (int): Number of slides per batch. Tiles are zero-padded to the longest in batch. + dtype (str): Dtype to cast features to at load time; simulates low-precision storage. + Options: 'float32' (default), 'float16', 'bfloat16'. + The model always trains in float32 regardless of this setting. + head_dim (int): Hidden dimension for each ABMIL attention head. + n_heads (int): Number of independent attention heads in ABMIL. + dropout (float): Dropout probability within each attention head. + gated (bool): If True, use gated ABMIL (sigmoid gating on attention). + n_bootstrap (int): Bootstrap resamples for 95% CI on test AUROC (primary seed only). + """ + + features_dir: str = MISSING + labels_parquet: str = MISSING + target_col: str = MISSING + split_col: Optional[str] = None + output_summary_json: str = "results.json" + n_seeds: int = 3 + random_state: int = 42 + test_size: float = 0.2 + val_size: float = 0.1 + n_epochs: int = 20 + lr: float = 1e-4 + batch_size: int = 8 + dtype: str = "float32" + head_dim: int = 256 + n_heads: int = 8 + dropout: float = 0.0 + gated: bool = False + n_bootstrap: int = 1000 + + +desc_doc = """== ${hydra.help.app_name} == +Benchmark ABMIL classification on patch features extracted from whole slide images. +""" + +parameter_doc = f""" +== Available Parameters == +{AbmilBenchmarkConfig.__doc__} +""" + +cs = ConfigStore.instance() +cs.store( + group="hydra", + name="config", + node=HydraConf(help=HelpConf(header=desc_doc, footer=parameter_doc)), + provider="hydra", +) +cs.store(name="abmil_benchmark_config", node=AbmilBenchmarkConfig) + + +# ── Data ────────────────────────────────────────────────────────────────────── + + +def _load_features(h5_path: Path, cast_dtype: torch.dtype) -> torch.Tensor: + """Load patch features from an HDF5 file and cast to the requested dtype. + + Args: + h5_path: Path to the ``.h5`` file. Must contain a ``'features'`` dataset + of shape ``(n_tiles, feature_dim)``. + cast_dtype: Target torch dtype. Features are cast to this dtype after loading, + which may reduce precision (e.g. float32 → float16). + + Returns: + Tensor of shape ``(n_tiles, feature_dim)`` in the requested dtype. + """ + with h5py.File(h5_path, "r") as f: + arr = np.array(f["features"]) + features = _numpy_to_torch(arr) # handles bfloat16 |V2 transparently + return features.to(cast_dtype) + + +class SlideDataset(Dataset): + """Dataset of per-slide patch feature tensors loaded from HDF5 files. + + Args: + slide_ids: Ordered list of slide identifiers. + labels: Integer labels (0/1) aligned with ``slide_ids``. + features_dir: Directory containing ``{slide_id}.h5`` files. + cast_dtype: Torch dtype to cast features to at load time. + """ + + def __init__( + self, + slide_ids: List[str], + labels: np.ndarray, + features_dir: Path, + cast_dtype: torch.dtype, + ): + self.slide_ids = slide_ids + self.labels = labels + self.features_dir = features_dir + self.cast_dtype = cast_dtype + + def __len__(self) -> int: + return len(self.slide_ids) + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + features = _load_features( + self.features_dir / f"{self.slide_ids[idx]}.h5", self.cast_dtype + ) + label = torch.tensor(self.labels[idx], dtype=torch.float32) + return features, label + + +def _collate_fn( + batch: List[Tuple[torch.Tensor, torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Collate variable-length tile sequences into a padded batch. + + Returns: + padded: ``(B, N_max, D)`` zero-padded feature tensor (float32). + labels: ``(B,)`` label tensor. + mask: ``(B, N_max)`` boolean mask; True entries are valid tiles. + """ + features_list, labels_list = zip(*batch) + max_tiles = max(f.shape[0] for f in features_list) + feature_dim = features_list[0].shape[1] + + padded = torch.zeros(len(features_list), max_tiles, feature_dim, dtype=torch.float32) + mask = torch.zeros(len(features_list), max_tiles, dtype=torch.bool) + for i, f in enumerate(features_list): + n = f.shape[0] + padded[i, :n] = f.float() # upcast to float32; model always trains in float32 + mask[i, :n] = True + + return padded, torch.stack(labels_list), mask + + +# ── Model ───────────────────────────────────────────────────────────────────── + + +class AbmilClassifier(nn.Module): + """ABMIL attention pooling followed by a binary classification head. + + Args: + feature_dim: Input feature dimension. + head_dim: Hidden dimension for each ABMIL attention head. + n_heads: Number of independent attention heads. + dropout: Dropout inside attention heads. + gated: Enable sigmoid gating (Gated-ABMIL). + """ + + def __init__( + self, + feature_dim: int, + head_dim: int = 256, + n_heads: int = 8, + dropout: float = 0.0, + gated: bool = False, + ): + super().__init__() + self.abmil = ABMIL( + feature_dim=feature_dim, + head_dim=head_dim, + n_heads=n_heads, + dropout=dropout, + n_branches=1, + gated=gated, + ) + self.head = nn.Linear(feature_dim, 1) + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Args: + x: ``(B, N, D)`` patch feature tensor. + mask: ``(B, N)`` boolean mask; True = valid tile. + + Returns: + ``(B,)`` logit tensor. + """ + aggregated, _ = self.abmil(x, attn_mask=mask) # (B, 1, D) + aggregated = aggregated.squeeze(1) # (B, D) + return self.head(aggregated).squeeze(-1) # (B,) + + +# ── Splits ──────────────────────────────────────────────────────────────────── + + +def _split_by_slide( + df: pd.DataFrame, + test_size: float, + val_size: float, + seed: int, +) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Random train / val / test split at slide level. + + All rows for a given slide land in the same partition to prevent data leakage. + """ + slide_ids = np.sort(df["slide_id"].unique()) + rng = np.random.RandomState(seed) + slide_ids = rng.permutation(slide_ids) + + n = len(slide_ids) + n_test = max(1, int(round(n * test_size))) + n_val = max(1, int(round(n * val_size))) + + test_ids = set(slide_ids[:n_test]) + val_ids = set(slide_ids[n_test : n_test + n_val]) + train_ids = set(slide_ids[n_test + n_val :]) + + logger.info( + "Split (seed=%d): train=%d val=%d test=%d slides", + seed, + len(train_ids), + len(val_ids), + len(test_ids), + ) + return ( + df[df["slide_id"].isin(train_ids)], + df[df["slide_id"].isin(val_ids)], + df[df["slide_id"].isin(test_ids)], + ) + + +# ── Evaluation ──────────────────────────────────────────────────────────────── + + +@torch.no_grad() +def _eval_auc( + model: nn.Module, + loader: DataLoader, + device: torch.device, +) -> float: + """Compute slide-level AUROC over a dataloader.""" + model.eval() + all_probs, all_labels = [], [] + for features, labels, mask in loader: + features = features.to(device) + mask = mask.to(device) + logits = model(features, mask) + probs = torch.sigmoid(logits).cpu().numpy() + all_probs.append(probs) + all_labels.append(labels.numpy()) + all_probs = np.concatenate(all_probs) + all_labels = np.concatenate(all_labels) + if all_labels.sum() == 0 or all_labels.sum() == len(all_labels): + logger.warning("AUROC undefined: only one class present in labels.") + return float("nan") + return float(roc_auc_score(all_labels, all_probs)) + + +def _bootstrap_ci_auc( + probs: np.ndarray, + labels: np.ndarray, + n_bootstrap: int, + seed: int, +) -> Tuple[float, float]: + """Bootstrap 95% CI for AUROC.""" + rng = np.random.RandomState(seed) + aucs = [] + n = len(labels) + for _ in range(n_bootstrap): + idx = rng.randint(0, n, size=n) + y_bs, p_bs = labels[idx], probs[idx] + if y_bs.sum() in (0, n): + continue + aucs.append(roc_auc_score(y_bs, p_bs)) + if not aucs: + return float("nan"), float("nan") + lo, hi = float(np.percentile(aucs, 2.5)), float(np.percentile(aucs, 97.5)) + return lo, hi + + +# ── Training ────────────────────────────────────────────────────────────────── + + +def _run_one_seed( + cfg: AbmilBenchmarkConfig, + df: pd.DataFrame, + features_dir: Path, + device: torch.device, + seed: int, +) -> Tuple[Dict, np.ndarray, np.ndarray]: + """Train ABMIL for one seed. Returns metrics dict and primary-seed arrays. + + Returns: + metrics: dict with 'val' and 'test' sub-dicts containing 'auroc'. + test_probs: predicted probabilities on the test set. + test_labels: ground-truth labels on the test set. + """ + torch.manual_seed(seed) + np.random.seed(seed) + + if cfg.split_col is not None: + train_df = df[df[cfg.split_col] == "train"].reset_index(drop=True) + val_df = df[df[cfg.split_col] == "val"].reset_index(drop=True) + test_df = df[df[cfg.split_col] == "test"].reset_index(drop=True) + logger.info( + "Using split_col=%r: train=%d val=%d test=%d slides", + cfg.split_col, + len(train_df), + len(val_df), + len(test_df), + ) + else: + train_df, val_df, test_df = _split_by_slide(df, cfg.test_size, cfg.val_size, seed) + + cast_dtype = _TORCH_DTYPE[cfg.dtype] + + def _make_loader(split_df: pd.DataFrame, shuffle: bool) -> DataLoader: + ds = SlideDataset( + split_df["slide_id"].tolist(), + split_df["y"].values, + features_dir, + cast_dtype, + ) + return DataLoader( + ds, + batch_size=cfg.batch_size, + shuffle=shuffle, + collate_fn=_collate_fn, + num_workers=min(4, len(split_df)), + pin_memory=device.type == "cuda", + ) + + # Infer feature_dim from the first slide in the dataset. + first_h5 = features_dir / f"{df['slide_id'].iloc[0]}.h5" + with h5py.File(first_h5, "r") as f: + feature_dim = f["features"].shape[1] + logger.info("feature_dim=%d dtype=%s device=%s", feature_dim, cfg.dtype, device) + + model = AbmilClassifier( + feature_dim=feature_dim, + head_dim=cfg.head_dim, + n_heads=cfg.n_heads, + dropout=cfg.dropout, + gated=cfg.gated, + ).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr) + criterion = nn.BCEWithLogitsLoss() + + train_loader = _make_loader(train_df, shuffle=True) + val_loader = _make_loader(val_df, shuffle=False) + test_loader = _make_loader(test_df, shuffle=False) + + best_val_auc = float("-inf") + # Initialise with current weights so load_state_dict never receives None. + best_state: Dict = {k: v.cpu().clone() for k, v in model.state_dict().items()} + + for epoch in range(cfg.n_epochs): + model.train() + epoch_loss = 0.0 + for features, labels, mask in train_loader: + features = features.to(device) + labels = labels.to(device) + mask = mask.to(device) + optimizer.zero_grad() + logits = model(features, mask) + loss = criterion(logits, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + + val_auc = _eval_auc(model, val_loader, device) + if val_auc > best_val_auc: + best_val_auc = val_auc + best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} + logger.debug( + "epoch=%d/%d loss=%.4f val_auc=%.4f best=%.4f", + epoch + 1, + cfg.n_epochs, + epoch_loss / max(1, len(train_loader)), + val_auc, + best_val_auc, + ) + + model.load_state_dict(best_state) + test_auc = _eval_auc(model, test_loader, device) + logger.info("seed=%d best_val_auc=%.4f test_auc=%.4f", seed, best_val_auc, test_auc) + + # Collect raw probabilities and labels for bootstrap CI (caller uses primary seed only). + test_probs, test_labels = [], [] + model.eval() + with torch.no_grad(): + for features, labels, mask in test_loader: + logits = model(features.to(device), mask.to(device)) + test_probs.append(torch.sigmoid(logits).cpu().numpy()) + test_labels.append(labels.numpy()) + test_probs = np.concatenate(test_probs) + test_labels = np.concatenate(test_labels) + + return ( + {"val": {"auroc": best_val_auc}, "test": {"auroc": test_auc}}, + test_probs, + test_labels, + ) + + +# ── JSON helpers ────────────────────────────────────────────────────────────── + + +def _sanitize_for_json(obj): + """Recursively convert numpy scalars/arrays to native Python types.""" + if isinstance(obj, (np.integer,)): + return int(obj) + if isinstance(obj, (np.floating,)): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, dict): + return {k: _sanitize_for_json(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_sanitize_for_json(v) for v in obj] + return obj + + +# ── Main ────────────────────────────────────────────────────────────────────── + + +@hydra.main(version_base=None, config_path=".", config_name="abmil_benchmark_config") +def main(cfg: AbmilBenchmarkConfig): + """Benchmark ABMIL on WSI patch features for the given dtype and target.""" + if cfg.dtype not in _VALID_DTYPES: + raise ValueError( + f"dtype={cfg.dtype!r} is not supported. Choose from {_VALID_DTYPES}." + ) + + features_dir = Path(cfg.features_dir) + if not features_dir.is_dir(): + raise FileNotFoundError(f"features_dir does not exist: {features_dir}") + + # Load labels; drop slides without a corresponding H5 file. + df = pd.read_parquet(cfg.labels_parquet, columns=_label_columns(cfg)) + available = {p.stem for p in features_dir.glob("*.h5")} + before = len(df) + df = df[df["slide_id"].isin(available)].reset_index(drop=True) + if len(df) < before: + logger.warning( + "Dropped %d slides from labels parquet (no matching .h5 in features_dir).", + before - len(df), + ) + if df.empty: + raise RuntimeError( + "No slides remain after filtering. Check features_dir and labels_parquet." + ) + + # Rename target column to 'y'. + df = df.rename(columns={cfg.target_col: "y"}) + df["y"] = df["y"].astype(int) + pos_rate = df["y"].mean() + logger.info( + "Dataset: %d slides pos_rate=%.3f dtype=%s", len(df), pos_rate, cfg.dtype + ) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info("Using device: %s", device) + + seeds = [cfg.random_state + i for i in range(cfg.n_seeds)] + all_val_metrics: List[Dict] = [] + all_test_metrics: List[Dict] = [] + primary_test_probs: Optional[np.ndarray] = None + primary_test_labels: Optional[np.ndarray] = None + + for seed in seeds: + logger.info("── seed=%d " + "─" * 40, seed) + metrics, test_probs, test_labels = _run_one_seed( + cfg, df, features_dir, device, seed + ) + all_val_metrics.append(metrics["val"]) + all_test_metrics.append(metrics["test"]) + if seed == cfg.random_state: + primary_test_probs = test_probs + primary_test_labels = test_labels + + # ── Multi-seed summary ───────────────────────────────────────────────────── + all_keys = sorted({k for m in all_val_metrics + all_test_metrics for k in m}) + summary: Dict = { + "dtype": cfg.dtype, + "target_col": cfg.target_col, + "n_slides": int(len(df)), + "pos_rate": float(pos_rate), + "n_seeds": len(seeds), + "seeds": seeds, + } + for split_name, records in [("val", all_val_metrics), ("test", all_test_metrics)]: + summary[split_name] = {} + for key in all_keys: + vals = [r[key] for r in records if key in r] + if not vals: + continue + mean, std = float(np.mean(vals)), float(np.std(vals)) + summary[split_name][key] = {"mean": mean, "std": std} + logger.info("[%s] %s: %.4f +/- %.4f", split_name, key, mean, std) + + # ── Bootstrap CI on test AUROC (primary seed) ────────────────────────────── + if primary_test_probs is not None and not np.isnan( + summary["test"].get("auroc", {}).get("mean", float("nan")) + ): + ci_lo, ci_hi = _bootstrap_ci_auc( + primary_test_probs, + primary_test_labels, + cfg.n_bootstrap, + cfg.random_state, + ) + summary["test"]["auroc"]["bootstrap_ci_95"] = [ci_lo, ci_hi] + logger.info("[test] auroc bootstrap 95%% CI: [%.4f, %.4f]", ci_lo, ci_hi) + + with open(cfg.output_summary_json, "w") as f: + json.dump(_sanitize_for_json(summary), f, indent=2) + logger.info("Results written to %s", cfg.output_summary_json) + + +def _label_columns(cfg: AbmilBenchmarkConfig) -> List[str]: + """Determine which parquet columns to load.""" + cols = ["slide_id", cfg.target_col] + if cfg.split_col is not None: + cols.append(cfg.split_col) + return cols diff --git a/pyproject.toml b/pyproject.toml index c1768b59..d7ce0793 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -191,6 +191,7 @@ create_class_embeddings = 'mussel.cli.create_class_embeddings:main' merge_annotation_features = 'mussel.cli.merge_annotation_features:main' linear_probe_benchmark = 'mussel.cli.linear_probe_benchmark:main' clustering_benchmark = 'mussel.cli.clustering_benchmark:main' +abmil_benchmark = 'mussel.cli.abmil_benchmark:main' export_tiles = 'mussel.cli.export_tiles:main' save_model = 'mussel.cli.save_model:main' convert = 'mussel.cli.convert:main' diff --git a/tests/mussel/cli/test_abmil_benchmark.py b/tests/mussel/cli/test_abmil_benchmark.py new file mode 100644 index 00000000..460615b1 --- /dev/null +++ b/tests/mussel/cli/test_abmil_benchmark.py @@ -0,0 +1,221 @@ +"""Tests for mussel.cli.abmil_benchmark.""" + +import json +import os +from pathlib import Path + +import h5py +import numpy as np +import pandas as pd +import pytest +import torch + +from mussel.cli.abmil_benchmark import ( + AbmilBenchmarkConfig, + AbmilClassifier, + SlideDataset, + _bootstrap_ci_auc, + _collate_fn, + _load_features, + _split_by_slide, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_h5_features(path: Path, n_tiles: int = 50, feature_dim: int = 64, dtype: str = "float32", seed: int = 0): + """Write a synthetic H5 feature file.""" + rng = np.random.default_rng(seed) + arr = rng.standard_normal((n_tiles, feature_dim)).astype(dtype) + with h5py.File(path, "w") as f: + f.create_dataset("features", data=arr) + + +def _make_features_dir(tmp_path: Path, n_slides: int = 10, feature_dim: int = 64, seed: int = 0) -> Path: + """Create a directory of per-slide H5 feature files.""" + features_dir = tmp_path / "features" + features_dir.mkdir() + rng = np.random.default_rng(seed) + for i in range(n_slides): + n_tiles = rng.integers(30, 80) + _make_h5_features(features_dir / f"slide_{i:03d}.h5", int(n_tiles), feature_dim, seed=i) + return features_dir + + +def _make_labels_parquet(tmp_path: Path, n_slides: int = 10, target_col: str = "label", seed: int = 0) -> Path: + """Create a synthetic labels parquet.""" + rng = np.random.default_rng(seed) + rows = [ + {"slide_id": f"slide_{i:03d}", target_col: int(rng.integers(0, 2))} + for i in range(n_slides) + ] + df = pd.DataFrame(rows) + path = tmp_path / "labels.parquet" + df.to_parquet(path) + return path + + +# --------------------------------------------------------------------------- +# Unit tests: data loading +# --------------------------------------------------------------------------- + + +def test_load_features_float32(tmp_path): + h5_path = tmp_path / "slide.h5" + _make_h5_features(h5_path, n_tiles=40, feature_dim=16, dtype="float32") + features = _load_features(h5_path, torch.float32) + assert features.shape == (40, 16) + assert features.dtype == torch.float32 + + +def test_load_features_cast_to_float16(tmp_path): + h5_path = tmp_path / "slide.h5" + _make_h5_features(h5_path, n_tiles=40, feature_dim=16, dtype="float32") + features = _load_features(h5_path, torch.float16) + assert features.dtype == torch.float16 + # Values should differ from float32 due to reduced precision. + features_f32 = _load_features(h5_path, torch.float32) + assert not torch.equal(features.float(), features_f32) + + +def test_slide_dataset(tmp_path): + features_dir = _make_features_dir(tmp_path, n_slides=5, feature_dim=32) + slide_ids = [f"slide_{i:03d}" for i in range(5)] + labels = np.array([0, 1, 0, 1, 0]) + ds = SlideDataset(slide_ids, labels, features_dir, cast_dtype=torch.float32) + assert len(ds) == 5 + features, label = ds[2] + assert features.ndim == 2 + assert features.shape[1] == 32 + assert label.item() == 0 + + +def test_collate_fn_pads_correctly(): + """Batch of slides with different tile counts should be padded correctly.""" + feat_dim = 8 + b1 = torch.ones(10, feat_dim) + b2 = torch.ones(20, feat_dim) * 2 + b3 = torch.ones(15, feat_dim) * 3 + labels = [torch.tensor(0.0), torch.tensor(1.0), torch.tensor(0.0)] + batch = [(b1, labels[0]), (b2, labels[1]), (b3, labels[2])] + padded, lbls, mask = _collate_fn(batch) + assert padded.shape == (3, 20, feat_dim) + assert mask.shape == (3, 20) + assert padded.dtype == torch.float32 + assert mask[0, :10].all() and not mask[0, 10:].any() + assert mask[1, :20].all() + assert mask[2, :15].all() and not mask[2, 15:].any() + # Padding entries must be zero. + assert (padded[0, 10:] == 0).all() + + +# --------------------------------------------------------------------------- +# Unit tests: model +# --------------------------------------------------------------------------- + + +def test_abmil_classifier_forward(): + model = AbmilClassifier(feature_dim=32, head_dim=16, n_heads=2) + model.eval() + B, N, D = 4, 25, 32 + x = torch.randn(B, N, D) + mask = torch.ones(B, N, dtype=torch.bool) + with torch.no_grad(): + logits = model(x, mask) + assert logits.shape == (B,) + + +def test_abmil_classifier_with_padding(): + model = AbmilClassifier(feature_dim=16, head_dim=8, n_heads=1) + model.eval() + B, N_max, D = 2, 30, 16 + x = torch.randn(B, N_max, D) + # Slide 0 has 20 valid tiles, slide 1 has 30. + mask = torch.zeros(B, N_max, dtype=torch.bool) + mask[0, :20] = True + mask[1, :30] = True + with torch.no_grad(): + logits = model(x, mask) + assert logits.shape == (B,) + + +# --------------------------------------------------------------------------- +# Unit tests: splits +# --------------------------------------------------------------------------- + + +def test_split_by_slide_no_leakage(): + slide_ids = [f"slide_{i:03d}" for i in range(30)] + labels = [i % 2 for i in range(30)] + df = pd.DataFrame({"slide_id": slide_ids, "y": labels}) + train_df, val_df, test_df = _split_by_slide(df, test_size=0.2, val_size=0.1, seed=42) + all_ids = set(train_df["slide_id"]) | set(val_df["slide_id"]) | set(test_df["slide_id"]) + assert all_ids == set(slide_ids) + # No overlap between splits. + assert not (set(train_df["slide_id"]) & set(val_df["slide_id"])) + assert not (set(train_df["slide_id"]) & set(test_df["slide_id"])) + assert not (set(val_df["slide_id"]) & set(test_df["slide_id"])) + + +# --------------------------------------------------------------------------- +# Unit tests: bootstrap CI +# --------------------------------------------------------------------------- + + +def test_bootstrap_ci_auc(): + rng = np.random.default_rng(0) + probs = rng.uniform(0, 1, 100) + labels = (probs > 0.5).astype(int) + lo, hi = _bootstrap_ci_auc(probs, labels, n_bootstrap=200, seed=0) + assert lo < hi + assert 0.5 <= lo <= 1.0 + assert 0.5 <= hi <= 1.0 + + +# --------------------------------------------------------------------------- +# Integration test: end-to-end (CPU, tiny dataset) +# --------------------------------------------------------------------------- + + +def test_abmil_benchmark_end_to_end(tmp_path, monkeypatch): + """Smoke test: run one seed of training on a tiny synthetic dataset.""" + n_slides = 12 + feature_dim = 16 + features_dir = _make_features_dir(tmp_path, n_slides=n_slides, feature_dim=feature_dim) + labels_path = _make_labels_parquet(tmp_path, n_slides=n_slides) + output_path = str(tmp_path / "results.json") + + from mussel.cli.abmil_benchmark import AbmilBenchmarkConfig, _run_one_seed + + cfg = AbmilBenchmarkConfig( + features_dir=str(features_dir), + labels_parquet=str(labels_path), + target_col="label", + output_summary_json=output_path, + n_seeds=1, + random_state=0, + n_epochs=2, + batch_size=4, + dtype="float32", + head_dim=8, + n_heads=1, + ) + + df = pd.read_parquet(labels_path) + df = df.rename(columns={"label": "y"}) + df["y"] = df["y"].astype(int) + + metrics, test_probs, test_labels = _run_one_seed( + cfg, df, features_dir, torch.device("cpu"), seed=0 + ) + assert "auroc" in metrics["val"] + assert "auroc" in metrics["test"] + # With random labels, AUROC might be anywhere; just check it's a valid float. + val_auc = metrics["val"]["auroc"] + test_auc = metrics["test"]["auroc"] + assert isinstance(val_auc, float) + assert isinstance(test_auc, float) + assert test_probs.shape == test_labels.shape From 2d7df437bbfc6c78dde5c9a0298d138788e93546 Mon Sep 17 00:00:00 2001 From: Raymond Lim Date: Thu, 7 May 2026 14:33:29 -0400 Subject: [PATCH 3/6] fix: sanitize inf/nan in JSON output to null _sanitize_for_json now replaces inf and nan float values with None so the serialised JSON is valid (standard JSON has no representation for these special float values). Groovy's JsonSlurper would otherwise fail to parse results.json when val AUROC is undefined (single-class val split). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mussel/cli/abmil_benchmark.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mussel/cli/abmil_benchmark.py b/mussel/cli/abmil_benchmark.py index c960497a..965e0f32 100644 --- a/mussel/cli/abmil_benchmark.py +++ b/mussel/cli/abmil_benchmark.py @@ -459,11 +459,18 @@ def _make_loader(split_df: pd.DataFrame, shuffle: bool) -> DataLoader: def _sanitize_for_json(obj): - """Recursively convert numpy scalars/arrays to native Python types.""" + """Recursively convert numpy scalars/arrays to native Python types. + + Replaces ``inf`` and ``nan`` float values with ``None`` so the output is + valid JSON (standard JSON has no representation for these values). + """ if isinstance(obj, (np.integer,)): return int(obj) if isinstance(obj, (np.floating,)): - return float(obj) + v = float(obj) + return None if (np.isnan(v) or np.isinf(v)) else v + if isinstance(obj, float): + return None if (np.isnan(obj) or np.isinf(obj)) else obj if isinstance(obj, np.ndarray): return obj.tolist() if isinstance(obj, dict): From 2161e006849784ff7502bc6b92d33ee3da994d1e Mon Sep 17 00:00:00 2001 From: Raymond Lim Date: Wed, 17 Jun 2026 14:06:47 -0400 Subject: [PATCH 4/6] fix: guard _eval_auc against empty loader; use final-epoch weights when val AUC is always NaN - _eval_auc: return nan early if loader has zero batches to avoid np.concatenate([]) ValueError - _train_one_seed: use math.isnan guard so NaN val AUC never beats float('-inf'); initialize best_state=None and skip load_state_dict when val AUC was always undefined (single-class val split), keeping final-epoch weights instead of reverting to random initial weights Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mussel/cli/abmil_benchmark.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mussel/cli/abmil_benchmark.py b/mussel/cli/abmil_benchmark.py index 965e0f32..2166180b 100644 --- a/mussel/cli/abmil_benchmark.py +++ b/mussel/cli/abmil_benchmark.py @@ -16,6 +16,7 @@ import json import logging +import math from dataclasses import dataclass, field from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -300,6 +301,9 @@ def _eval_auc( probs = torch.sigmoid(logits).cpu().numpy() all_probs.append(probs) all_labels.append(labels.numpy()) + if not all_probs: + logger.warning("AUROC undefined: loader is empty.") + return float("nan") all_probs = np.concatenate(all_probs) all_labels = np.concatenate(all_labels) if all_labels.sum() == 0 or all_labels.sum() == len(all_labels): @@ -403,8 +407,7 @@ def _make_loader(split_df: pd.DataFrame, shuffle: bool) -> DataLoader: test_loader = _make_loader(test_df, shuffle=False) best_val_auc = float("-inf") - # Initialise with current weights so load_state_dict never receives None. - best_state: Dict = {k: v.cpu().clone() for k, v in model.state_dict().items()} + best_state: Optional[Dict] = None for epoch in range(cfg.n_epochs): model.train() @@ -421,7 +424,7 @@ def _make_loader(split_df: pd.DataFrame, shuffle: bool) -> DataLoader: epoch_loss += loss.item() val_auc = _eval_auc(model, val_loader, device) - if val_auc > best_val_auc: + if not math.isnan(val_auc) and val_auc > best_val_auc: best_val_auc = val_auc best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} logger.debug( @@ -433,7 +436,10 @@ def _make_loader(split_df: pd.DataFrame, shuffle: bool) -> DataLoader: best_val_auc, ) - model.load_state_dict(best_state) + # Restore best weights; if val AUC was always undefined (single-class val split), + # best_state is None and we keep the final-epoch weights instead. + if best_state is not None: + model.load_state_dict(best_state) test_auc = _eval_auc(model, test_loader, device) logger.info("seed=%d best_val_auc=%.4f test_auc=%.4f", seed, best_val_auc, test_auc) From 2d7988a9e5ef20c4de23ebfbefe4224abfad75bc Mon Sep 17 00:00:00 2001 From: Raymond Lim Date: Wed, 17 Jun 2026 14:13:53 -0400 Subject: [PATCH 5/6] fix: address PR review comments in abmil_benchmark - _split_by_slide: raise ValueError when n_test+n_val >= n (empty train split) - _train_one_seed: validate split_col exists and all three splits are non-empty - _make_loader: add multiprocessing_context='spawn' when num_workers>0 to prevent CUDA context corruption (mirrors feature_extract._make_dataloader) - emit NaN (not -inf) as val auroc when val AUC was always undefined; _sanitize_for_json already converts NaN->null for valid JSON output Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mussel/cli/abmil_benchmark.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/mussel/cli/abmil_benchmark.py b/mussel/cli/abmil_benchmark.py index 2166180b..824ea0a0 100644 --- a/mussel/cli/abmil_benchmark.py +++ b/mussel/cli/abmil_benchmark.py @@ -264,6 +264,14 @@ def _split_by_slide( n_test = max(1, int(round(n * test_size))) n_val = max(1, int(round(n * val_size))) + if n_test + n_val >= n: + raise ValueError( + f"Too few slides ({n}) for the requested split sizes " + f"(test_size={test_size}, val_size={val_size}). " + f"Need at least {n_test + n_val + 1} slides but got {n}. " + "Reduce test_size/val_size or provide more slides." + ) + test_ids = set(slide_ids[:n_test]) val_ids = set(slide_ids[n_test : n_test + n_val]) train_ids = set(slide_ids[n_test + n_val :]) @@ -355,9 +363,20 @@ def _run_one_seed( np.random.seed(seed) if cfg.split_col is not None: + if cfg.split_col not in df.columns: + raise ValueError( + f"split_col={cfg.split_col!r} not found in parquet columns: {list(df.columns)}" + ) train_df = df[df[cfg.split_col] == "train"].reset_index(drop=True) val_df = df[df[cfg.split_col] == "val"].reset_index(drop=True) test_df = df[df[cfg.split_col] == "test"].reset_index(drop=True) + for split_name, split_df in [("train", train_df), ("val", val_df), ("test", test_df)]: + if split_df.empty: + raise ValueError( + f"split_col={cfg.split_col!r}: '{split_name}' split is empty. " + "Ensure the parquet contains rows with " + f"{cfg.split_col}='{split_name}'." + ) logger.info( "Using split_col=%r: train=%d val=%d test=%d slides", cfg.split_col, @@ -377,12 +396,14 @@ def _make_loader(split_df: pd.DataFrame, shuffle: bool) -> DataLoader: features_dir, cast_dtype, ) + num_workers = min(4, len(split_df)) return DataLoader( ds, batch_size=cfg.batch_size, shuffle=shuffle, collate_fn=_collate_fn, - num_workers=min(4, len(split_df)), + num_workers=num_workers, + multiprocessing_context="spawn" if num_workers > 0 else None, pin_memory=device.type == "cuda", ) @@ -455,7 +476,10 @@ def _make_loader(split_df: pd.DataFrame, shuffle: bool) -> DataLoader: test_labels = np.concatenate(test_labels) return ( - {"val": {"auroc": best_val_auc}, "test": {"auroc": test_auc}}, + { + "val": {"auroc": best_val_auc if math.isfinite(best_val_auc) else float("nan")}, + "test": {"auroc": test_auc}, + }, test_probs, test_labels, ) From b7c18b07d08aeebe3d955443d7e45b377e7dc6c8 Mon Sep 17 00:00:00 2001 From: Raymond Lim Date: Wed, 17 Jun 2026 14:58:07 -0400 Subject: [PATCH 6/6] fix: use patch.object(mp, 'cpu_count') to avoid importlib attribute chain issue Under --import-mode=importlib, mock's _dot_lookup can fail to resolve 'mussel.utils.converter' as an attribute of 'mussel.utils' in Python 3.10 when running the full test suite. Replace the fragile string-based patch target with patch.object(mp, 'cpu_count') using the multiprocessing module already imported at the top of the test file. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/mussel/cli/test_convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mussel/cli/test_convert.py b/tests/mussel/cli/test_convert.py index 5d49dac5..f63016bb 100644 --- a/tests/mussel/cli/test_convert.py +++ b/tests/mussel/cli/test_convert.py @@ -485,7 +485,7 @@ def test_num_workers_zero_uses_cpu_count(self, tmp_path): c = self._make_converter(tmp_path) with patch.object(c, "process_file") as mock_pf, \ - patch("mussel.utils.converter.mp.cpu_count", return_value=1): + patch.object(mp, "cpu_count", return_value=1): c.process_all(str(input_dir), mpp_csv=csv_path, num_workers=0) mock_pf.assert_called_once()