feat: add abmil_benchmark CLI for precision benchmarking #124

Open
raylim wants to merge 7 commits into main from feat/abmil-benchmark

raylim (Collaborator) commented May 7, 2026

Summary

Adds a new abmil_benchmark CLI to the Mussel package that trains a Gated-ABMIL classifier on per-slide H5 feature files and reports AUROC across multiple seeds with bootstrap 95% CIs.

This supports benchmarking how lower-precision feature storage (float32 / float16 / bfloat16) affects downstream slide-level ABMIL classification performance, complementing the existing tile-level linear_probe_benchmark.

Used by: mussel-nf PR #20 — wires this CLI into the precision-benchmarking pipeline as ABMIL_BENCHMARK_WORKFLOW


Design

  • Follows linear_probe_benchmark conventions (Hydra config, ConfigStore, same JSON output format)
  • Reads per-slide H5 files (h5["features"] shape: n_tiles × feature_dim)
  • Features are cast to cfg.dtype (float32 / float16 / bfloat16) at load time to simulate the information loss from reduced-precision storage; the model always trains in float32 for a clean benchmark signal
  • AbmilClassifier wraps mussel.models.abmil.ABMIL (n_branches=1) + nn.Linear head
  • Supports an optional pre-defined split column in the labels parquet, or falls back to a random slide-level split
  • bfloat16 H5 storage handled via _numpy_to_torch (|V2 opaque void detection)
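
The effect of the load-time cast can be illustrated without any array library. The following is a minimal sketch, not the code in this PR: it round-trips a Python float through float16 and bfloat16 using only the standard library, so the rounding error the benchmark is meant to measure becomes visible. The helper names round_to_float16 and round_to_bfloat16 are hypothetical.

```python
import struct


def round_to_float16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value, returned as a Python float."""
    # struct's "e" format is binary16: packing rounds, unpacking widens back.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def round_to_bfloat16(x: float) -> float:
    """Round x to bfloat16 (the top 16 bits of a float32), ties to even.

    NaN and overflow handling are omitted to keep the sketch short.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Adding 0x7FFF plus the lowest kept bit rounds to nearest, ties to even.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


if __name__ == "__main__":
    for value in (0.1, 1.0, 3.14159):
        f16_err = abs(round_to_float16(value) - value)
        bf16_err = abs(round_to_bfloat16(value) - value)
        print(f"{value!r}: float16 error {f16_err:.2e}, bfloat16 error {bf16_err:.2e}")
```

bfloat16 keeps the 8 exponent bits of float32 but only 7 explicit mantissa bits, against 10 for float16, so for values well inside the float16 range the bfloat16 rounding error is typically the larger of the two.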

Output format (results.json)

{
  "dtype": "float32",
  "target_col": "label",
  "n_slides": 200,
  "pos_rate": 0.48,
  "n_seeds": 3,
  "seeds": [42, 43, 44],
  "val":  { "auroc": { "mean": 0.82, "std": 0.03 } },
  "test": { "auroc": { "mean": 0.79, "std": 0.02, "bootstrap_ci_95": [0.71, 0.86] } }
}

Val AUROC mean will be null when the val split is single-class (too few slides). The Nextflow summary process handles this gracefully.
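
The PR description does not show how bootstrap_ci_95 is computed. One common approach is a percentile bootstrap over the test slides, sketched here with the standard library only. The function names, the resample count, and the choice to skip single-class resamples are all assumptions made for illustration.

```python
import random


def auroc(labels, scores):
    """AUROC via the Mann-Whitney statistic; ties count as half.

    Returns None when only one class is present, since AUROC is undefined there.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return None
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def bootstrap_ci_95(labels, scores, n_boot=1000, seed=0):
    """Percentile bootstrap: resample slides with replacement, skip undefined resamples."""
    rng = random.Random(seed)
    n = len(labels)
    stats = []
    for _ in range(n_boot):
        idx = [rng.randrange(n) for _ in range(n)]
        value = auroc([labels[i] for i in idx], [scores[i] for i in idx])
        if value is not None:
            stats.append(value)
    if not stats:
        return None
    stats.sort()
    lo = stats[int(0.025 * (len(stats) - 1))]
    hi = stats[int(0.975 * (len(stats) - 1))]
    return [lo, hi]
```

Returning None for an undefined AUROC mirrors the null val AUROC in results.json: a single-class split has no ranking to score.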

CLI usage

abmil_benchmark \
    features_dir=/path/to/h5s \
    labels_parquet=/path/to/labels.parquet \
    target_col=my_binary_label \
    dtype=float16 \
    n_seeds=3 \
    n_epochs=20 \
    output_summary_json=results.json

Files changed

| File | Change |
| --- | --- |
| mussel/cli/abmil_benchmark.py | New CLI (~400 lines) |
| pyproject.toml | Add abmil_benchmark entry point |
| tests/mussel/cli/test_abmil_benchmark.py | New test file (9 tests) |

Tests

All 9 tests pass:

| Test | What it covers |
| --- | --- |
| test_load_h5_features_float32 | H5 loading, float32 passthrough |
| test_load_h5_features_float16_cast | dtype cast at load time |
| test_slide_dataset_getitem | SlideDataset.__getitem__ |
| test_collate_fn_padding | padding to N_max, mask shape |
| test_collate_fn_mask_correctness | mask marks padded positions |
| test_abmil_classifier_forward | model forward without padding |
| test_abmil_classifier_forward_with_padding | model forward with mask |
| test_split_no_leakage | slide IDs don't appear in multiple splits |
| test_bootstrap_ci | 95% CI bounds are ordered and within [0,1] |
| test_end_to_end | smoke test: CPU, synthetic dataset, valid JSON output |
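
The padding behaviour exercised by the two collate tests can be sketched with plain Python lists. This is an illustration only: the real collate function presumably works on tensors, the name collate_with_padding is hypothetical, and the mask polarity (True for padded positions) is an assumption based on the phrase "mask marks padded positions".

```python
def collate_with_padding(bags):
    """Pad variable-length bags of feature vectors to the longest bag.

    bags is a list of non-empty bags; each bag is a list of equal-length
    feature vectors. Returns (padded, mask), where mask[i][j] is True
    for padded positions (assumed polarity).
    """
    n_max = max(len(bag) for bag in bags)
    feature_dim = len(bags[0][0])
    padded, mask = [], []
    for bag in bags:
        n_pad = n_max - len(bag)
        # Copy the real tiles, then append zero vectors up to n_max.
        padded.append([list(tile) for tile in bag] + [[0.0] * feature_dim for _ in range(n_pad)])
        mask.append([False] * len(bag) + [True] * n_pad)
    return padded, mask


if __name__ == "__main__":
    padded, mask = collate_with_padding([[[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]]])
    print(padded)
    print(mask)
```

The mask is what lets an attention layer ignore the zero vectors, so that a slide's score does not depend on how many tiles its batch neighbours have.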

Bug fixes (discovered during testing)

  1. best_state initialization: initialized to None, causing a crash when the val AUROC was always undefined (single-class val split). Fixed: initialize to the model's initial weights before the training loop.
  2. _sanitize_for_json inf/nan: json.dumps(float('-inf')) produces -Infinity, which is not valid JSON. Groovy's JsonSlurper (used in Nextflow) raises JsonInternalException when parsing it. Fixed: replace inf/nan with None before serialization.
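
A minimal sketch of the second fix, assuming the summary is a nested structure of dicts, lists, and plain Python floats. The function below is an illustration written for this description, not the PR's _sanitize_for_json.

```python
import json
import math


def sanitize_for_json(obj):
    """Recursively replace non-finite floats (inf, -inf, nan) with None."""
    if isinstance(obj, float):
        return obj if math.isfinite(obj) else None
    if isinstance(obj, dict):
        return {key: sanitize_for_json(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [sanitize_for_json(value) for value in obj]
    return obj


summary = {"val": {"auroc": {"mean": float("nan"), "std": float("-inf")}}}
# allow_nan=False makes json.dumps raise ValueError if a non-finite float slips through.
text = json.dumps(sanitize_for_json(summary), allow_nan=False)
print(text)
```

Passing allow_nan=False is a useful safety net in its own right: by default json.dumps emits NaN, Infinity, and -Infinity without complaint.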

raylim and others added 2 commits May 7, 2026 13:50
…features

Same raw h5 read pattern as merge_annotation_features (line 142).
After np.array(h5["features"]), cast bfloat16 (|V2 opaque void) and
float16 to float32 before passing to downstream aggregation.

- bfloat16 stored as |V2 via ml_dtypes: view + astype(float32)
- float16: direct astype(float32)

Tests added:
- test_float16_features_upcast_to_float32
- test_bfloat16_features_upcast_to_float32

All 17 tests pass.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Adds a new abmil_benchmark CLI entry point to the Mussel package that
trains a Gated-ABMIL classifier on per-slide H5 feature files and
reports AUROC across multiple seeds with bootstrap 95% CIs.

Key design:
- Follows linear_probe_benchmark conventions (Hydra config, ConfigStore,
  same JSON output format)
- Reads per-slide H5 files (h5["features"] shape: n_tiles × feature_dim)
- cfg.dtype (float32/float16/bfloat16) cast at load time to simulate
  the precision loss from reduced-precision feature storage
- AbmilClassifier wraps mussel.models.abmil.ABMIL (n_branches=1)
  with a linear head; model trains in float32 regardless of input dtype
- Supports optional pre-defined split column or random slide-level split
- Initialises best_state to model weights before training so
  load_state_dict never receives None when val set is single-class

Files changed:
- mussel/cli/abmil_benchmark.py (new)
- pyproject.toml: add abmil_benchmark entry point
- tests/mussel/cli/test_abmil_benchmark.py (new, 9 tests)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI review requested due to automatic review settings May 7, 2026 18:10

Copilot AI left a comment


Pull request overview

Adds a new abmil_benchmark command-line tool to benchmark ABMIL slide classification performance across simulated feature-storage precisions, plus supporting tests and a small precision-handling fix in aggregate_sample_features.

Changes:

  • Introduces mussel.cli.abmil_benchmark (Hydra-configured) to train/evaluate an ABMIL classifier over per-slide H5 features and report AUROC across seeds with bootstrap CIs.
  • Adds an abmil_benchmark console script entry point.
  • Extends aggregate_sample_features to upcast float16/bfloat16 features to float32 on read, with new regression tests.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| mussel/cli/abmil_benchmark.py | New benchmark CLI: H5 feature loading, slide dataset/collate, ABMIL classifier, split/train/eval, JSON summary output. |
| pyproject.toml | Registers abmil_benchmark as a package entry point. |
| tests/mussel/cli/test_abmil_benchmark.py | Adds unit + integration smoke tests for the new CLI components. |
| mussel/cli/aggregate_sample_features.py | Upcasts float16/bfloat16 H5 feature arrays to float32 for downstream compatibility. |
| tests/mussel/cli/test_aggregate_sample_features.py | Adds tests validating float16/bfloat16 upcast behavior; minor formatting updates. |

_sanitize_for_json now replaces inf and nan float values with None
so the serialised JSON is valid (standard JSON has no representation
for these special float values). Groovy's JsonSlurper would otherwise
fail to parse results.json when val AUROC is undefined (single-class
val split).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
raylim and others added 4 commits June 17, 2026 14:06
…en val AUC is always NaN

- _eval_auc: return nan early if loader has zero batches to avoid
  np.concatenate([]) ValueError
- _train_one_seed: use math.isnan guard so NaN val AUC never beats
  float('-inf'); initialize best_state=None and skip load_state_dict
  when val AUC was always undefined (single-class val split), keeping
  final-epoch weights instead of reverting to random initial weights

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
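
The checkpoint-selection logic described in the commit above can be sketched as follows. The function name select_best_epoch is hypothetical and the real _train_one_seed tracks model weights rather than an epoch index; the point here is only the NaN guard and the None result.

```python
import math


def select_best_epoch(val_aucs):
    """Return the index of the epoch with the highest defined val AUC.

    Returns None when every val AUC is NaN (for example a single-class
    val split); the caller can then keep the final-epoch weights instead
    of loading a saved state.
    """
    best_auc = float("-inf")
    best_epoch = None
    for epoch, auc in enumerate(val_aucs):
        # Explicit guard: an undefined AUC is never a candidate.
        if not math.isnan(auc) and auc > best_auc:
            best_auc, best_epoch = auc, epoch
    return best_epoch


if __name__ == "__main__":
    nan = float("nan")
    print(select_best_epoch([nan, 0.70, 0.80, nan]))
    print(select_best_epoch([nan, nan]))
```

A None result corresponds to skipping load_state_dict entirely, which is what keeps the trained final-epoch weights rather than reverting to the random initial ones.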
- _split_by_slide: raise ValueError when n_test+n_val >= n (empty train split)
- _train_one_seed: validate split_col exists and all three splits are non-empty
- _make_loader: add multiprocessing_context='spawn' when num_workers>0 to
  prevent CUDA context corruption (mirrors feature_extract._make_dataloader)
- emit NaN (not -inf) as val auroc when val AUC was always undefined;
  _sanitize_for_json already converts NaN->null for valid JSON output

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
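
The empty-train-split check from the commit above might look like this in isolation. This is a sketch under stated assumptions: the function name, the fraction parameters and their defaults, and the rule of at least one slide per held-out split are all illustrative, not taken from _split_by_slide.

```python
import random


def split_by_slide(slide_ids, val_frac=0.15, test_frac=0.15, seed=42):
    """Randomly assign each unique slide ID to exactly one of train/val/test."""
    ids = sorted(set(slide_ids))  # deduplicate so a slide cannot land in two splits
    n = len(ids)
    n_val = max(1, round(n * val_frac))
    n_test = max(1, round(n * test_frac))
    if n_val + n_test >= n:
        raise ValueError(
            f"n_val + n_test = {n_val + n_test} leaves no training slides out of {n}"
        )
    random.Random(seed).shuffle(ids)
    return {
        "val": ids[:n_val],
        "test": ids[n_val:n_val + n_test],
        "train": ids[n_val + n_test:],
    }


if __name__ == "__main__":
    split = split_by_slide([f"slide_{i}" for i in range(20)])
    print({name: len(members) for name, members in split.items()})
```

Failing fast here gives a clearer error than letting an empty training loader surface later as an unrelated exception inside the training loop.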
# Conflicts:
#	mussel/cli/aggregate_sample_features.py
#	tests/mussel/cli/test_aggregate_sample_features.py
…hain issue

Under --import-mode=importlib, mock's _dot_lookup can fail to resolve
'mussel.utils.converter' as an attribute of 'mussel.utils' in Python 3.10
when running the full test suite. Replace the fragile string-based patch
target with patch.object(mp, 'cpu_count') using the multiprocessing module
already imported at the top of the test file.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
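
The patching pattern described in the commit above can be shown in isolation. In this sketch default_num_workers is a hypothetical stand-in for whatever code under test calls cpu_count; only patch.object and the multiprocessing module are real APIs.

```python
import multiprocessing as mp
from unittest.mock import patch


def default_num_workers():
    """Hypothetical helper: leave one core free, never drop below one worker."""
    return max(1, mp.cpu_count() - 1)


def test_default_num_workers_uses_patched_cpu_count():
    # patch.object takes the module object directly, so no dotted-string
    # import path has to be resolved at patch time.
    with patch.object(mp, "cpu_count", return_value=8):
        assert default_num_workers() == 7
    with patch.object(mp, "cpu_count", return_value=1):
        assert default_num_workers() == 1


if __name__ == "__main__":
    test_default_num_workers_uses_patched_cpu_count()
    print("ok")
```

This works because the helper looks up mp.cpu_count at call time. Code that does `from multiprocessing import cpu_count` binds its own reference and would need to be patched in its own namespace instead.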