feat: add abmil_benchmark CLI for precision benchmarking #124
Open
raylim wants to merge 7 commits into
…features

Same raw h5 read pattern as merge_annotation_features (line 142). After np.array(h5["features"]), cast bfloat16 (|V2 opaque void) and float16 to float32 before passing to downstream aggregation.

- bfloat16 stored as |V2 via ml_dtypes: view + astype(float32)
- float16: direct astype(float32)

Tests added:
- test_float16_features_upcast_to_float32
- test_bfloat16_features_upcast_to_float32

All 17 tests pass.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
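As a rough illustration of why the bfloat16 upcast above is lossless: bfloat16 is the upper 16 bits of an IEEE-754 float32, so the raw two-byte values that h5py exposes as |V2 can be reinterpreted and widened without rounding. The sketch below uses only the standard library rather than the numpy/ml_dtypes view + astype path the commit describes, and both function names are hypothetical.

```python
import struct


def bfloat16_bits_to_float32(bits: int) -> float:
    """Widen one bfloat16 value, given as its raw 16-bit pattern, to float32."""
    if not 0 <= bits <= 0xFFFF:
        raise ValueError("expected a 16-bit pattern")
    # bfloat16 keeps the sign, the 8 exponent bits and the top 7 mantissa bits
    # of a float32, so padding the low 16 bits with zeros is exact.
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


def float32_to_bfloat16_bits(value: float) -> int:
    """Narrow a float32 to bfloat16 by truncation (real libraries usually round)."""
    (word,) = struct.unpack("<I", struct.pack("<f", value))
    return word >> 16


bits = float32_to_bfloat16_bits(1.5)
print(hex(bits), bfloat16_bits_to_float32(bits))  # 0x3fc0 1.5
```

Narrowing is where information is lost; widening back to float32 only restores the storage width, not the dropped mantissa bits.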
Adds a new abmil_benchmark CLI entry point to the Mussel package that trains a Gated-ABMIL classifier on per-slide H5 feature files and reports AUROC across multiple seeds with bootstrap 95% CIs.

Key design:
- Follows linear_probe_benchmark conventions (Hydra config, ConfigStore, same JSON output format)
- Reads per-slide H5 files (h5["features"] shape: n_tiles × feature_dim)
- cfg.dtype (float32/float16/bfloat16) cast at load time to simulate the precision loss from reduced-precision feature storage
- AbmilClassifier wraps mussel.models.abmil.ABMIL (n_branches=1) with a linear head; model trains in float32 regardless of input dtype
- Supports optional pre-defined split column or random slide-level split
- Initialises best_state to model weights before training so load_state_dict never receives None when val set is single-class

Files changed:
- mussel/cli/abmil_benchmark.py (new)
- pyproject.toml: add abmil_benchmark entry point
- tests/mussel/cli/test_abmil_benchmark.py (new, 9 tests)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
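For readers unfamiliar with the "bootstrap 95% CIs" mentioned above, here is a minimal percentile-bootstrap sketch for AUROC. It assumes a plain percentile method and a pairwise AUROC; the PR's actual implementation, function names, and defaults are not shown on this page, so `auroc`, `bootstrap_ci`, `n_boot`, and `alpha` are all hypothetical.

```python
import math
import random


def auroc(labels, scores):
    """AUROC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return float("nan")  # undefined when only one class is present
    wins = sum((p > q) + 0.5 * (p == q) for p in pos for q in neg)
    return wins / (len(pos) * len(neg))


def bootstrap_ci(labels, scores, n_boot=1000, alpha=0.05, seed=0):
    """Percentile bootstrap CI: resample slides with replacement, recompute AUROC."""
    rng = random.Random(seed)
    n = len(labels)
    stats = []
    for _ in range(n_boot):
        idx = [rng.randrange(n) for _ in range(n)]
        value = auroc([labels[i] for i in idx], [scores[i] for i in idx])
        if not math.isnan(value):  # drop single-class resamples
            stats.append(value)
    if not stats:
        return float("nan"), float("nan")
    stats.sort()
    lo = stats[int((alpha / 2) * (len(stats) - 1))]
    hi = stats[int((1 - alpha / 2) * (len(stats) - 1))]
    return lo, hi


print(auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

The pairwise AUROC is quadratic in the number of slides, which is acceptable for a sketch but not for large test sets.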
Pull request overview
Adds a new abmil_benchmark command-line tool to benchmark ABMIL slide classification performance across simulated feature-storage precisions, plus supporting tests and a small precision-handling fix in aggregate_sample_features.
Changes:
- Introduces `mussel.cli.abmil_benchmark` (Hydra-configured) to train/evaluate an ABMIL classifier over per-slide H5 features and report AUROC across seeds with bootstrap CIs.
- Adds an `abmil_benchmark` console script entry point.
- Extends `aggregate_sample_features` to upcast float16/bfloat16 features to float32 on read, with new regression tests.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `mussel/cli/abmil_benchmark.py` | New benchmark CLI: H5 feature loading, slide dataset/collate, ABMIL classifier, split/train/eval, JSON summary output. |
| `pyproject.toml` | Registers `abmil_benchmark` as a package entry point. |
| `tests/mussel/cli/test_abmil_benchmark.py` | Adds unit + integration smoke tests for the new CLI components. |
| `mussel/cli/aggregate_sample_features.py` | Upcasts float16/bfloat16 H5 feature arrays to float32 for downstream compatibility. |
| `tests/mussel/cli/test_aggregate_sample_features.py` | Adds tests validating float16/bfloat16 upcast behavior; minor formatting updates. |
_sanitize_for_json now replaces inf and nan float values with None so the serialised JSON is valid (standard JSON has no representation for these special float values). Groovy's JsonSlurper would otherwise fail to parse results.json when val AUROC is undefined (single-class val split).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
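A minimal sketch of the sanitising step described above, assuming a recursive walk over dicts and lists; the PR's real `_sanitize_for_json` may differ, and `sanitize_for_json` here is a hypothetical stand-in. Passing `allow_nan=False` to `json.dumps` makes any value that slipped through raise instead of silently emitting `NaN` or `-Infinity`.

```python
import json
import math


def sanitize_for_json(obj):
    """Recursively replace non-finite floats (nan, inf, -inf) with None."""
    if isinstance(obj, float) and not math.isfinite(obj):
        return None
    if isinstance(obj, dict):
        return {key: sanitize_for_json(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [sanitize_for_json(value) for value in obj]
    return obj


summary = {
    "val": {"auroc": {"mean": float("nan"), "std": float("-inf")}},
    "seeds": [42, 43],
}
# allow_nan=False turns any remaining non-finite float into a ValueError.
text = json.dumps(sanitize_for_json(summary), allow_nan=False)
print(text)  # {"val": {"auroc": {"mean": null, "std": null}}, "seeds": [42, 43]}
```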
raylim
commented
Jun 17, 2026
…en val AUC is always NaN
- _eval_auc: return nan early if loader has zero batches to avoid
np.concatenate([]) ValueError
- _train_one_seed: use math.isnan guard so NaN val AUC never beats
float('-inf'); initialize best_state=None and skip load_state_dict
when val AUC was always undefined (single-class val split), keeping
final-epoch weights instead of reverting to random initial weights
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
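The checkpoint-selection logic in this commit can be sketched without PyTorch. The example below assumes one validation AUC and one state snapshot per epoch; `select_best` and its arguments are hypothetical names, and plain dicts stand in for `state_dict()` objects.

```python
import copy
import math


def select_best(val_aucs, states):
    """Return (best_val_auc, state), ignoring epochs whose val AUC is NaN."""
    if not states or len(val_aucs) != len(states):
        raise ValueError("need one non-empty state per epoch")
    best_auc = float("-inf")
    best_state = None
    for auc, state in zip(val_aucs, states):
        # Explicit NaN guard: an undefined AUC never becomes the best score.
        if not math.isnan(auc) and auc > best_auc:
            best_auc = auc
            best_state = copy.deepcopy(state)
    if best_state is None:
        # Val AUC was undefined every epoch (e.g. single-class val split):
        # keep the final-epoch weights and report NaN rather than -inf.
        return float("nan"), states[-1]
    return best_auc, best_state


nan = float("nan")
print(select_best([0.6, nan, 0.8, 0.7], [{"w": 0}, {"w": 1}, {"w": 2}, {"w": 3}]))
# (0.8, {'w': 2})
```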
- _split_by_slide: raise ValueError when n_test+n_val >= n (empty train split)
- _train_one_seed: validate split_col exists and all three splits are non-empty
- _make_loader: add multiprocessing_context='spawn' when num_workers>0 to prevent CUDA context corruption (mirrors feature_extract._make_dataloader)
- emit NaN (not -inf) as val auroc when val AUC was always undefined; _sanitize_for_json already converts NaN->null for valid JSON output

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
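The empty-train-split check in the first bullet above might look roughly like this. The fractions, the rounding rule, and the name `split_by_slide` are assumptions made for illustration; only the `n_test + n_val >= n` condition comes from the commit message.

```python
import random


def split_by_slide(slide_ids, test_frac=0.2, val_frac=0.1, seed=0):
    """Random slide-level split that refuses to return an empty train set."""
    unique = sorted(set(slide_ids))  # de-duplicate so no slide spans two splits
    n = len(unique)
    n_test = max(1, round(n * test_frac))
    n_val = max(1, round(n * val_frac))
    if n_test + n_val >= n:
        raise ValueError(
            f"{n} slides cannot fill test ({n_test}) and val ({n_val}) "
            "and still leave a non-empty train split"
        )
    random.Random(seed).shuffle(unique)
    test = unique[:n_test]
    val = unique[n_test:n_test + n_val]
    train = unique[n_test + n_val:]
    return train, val, test


train, val, test = split_by_slide([f"slide_{i}" for i in range(10)], seed=1)
print(len(train), len(val), len(test))  # 7 1 2
```

Failing early here gives a clearer error than a crash deep inside the training loop.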
# Conflicts:
#	mussel/cli/aggregate_sample_features.py
#	tests/mussel/cli/test_aggregate_sample_features.py
…hain issue

Under --import-mode=importlib, mock's _dot_lookup can fail to resolve 'mussel.utils.converter' as an attribute of 'mussel.utils' in Python 3.10 when running the full test suite. Replace the fragile string-based patch target with patch.object(mp, 'cpu_count') using the multiprocessing module already imported at the top of the test file.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
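A small sketch of the `patch.object` pattern this commit switches to. `default_num_workers` is a hypothetical function standing in for the code under test; the point is that patching the attribute on the already-imported `multiprocessing` module object avoids resolving a dotted string path at all.

```python
import multiprocessing as mp
from unittest.mock import patch


def default_num_workers() -> int:
    """Hypothetical code under test: leave one core free, but use at least one."""
    return max(1, mp.cpu_count() - 1)


# Patch the attribute on the module object itself instead of passing a
# dotted-string target that mock would have to import and walk.
with patch.object(mp, "cpu_count", return_value=8):
    patched = default_num_workers()

print(patched)  # 7
```

This works because the function looks up `mp.cpu_count` at call time; code that did `from multiprocessing import cpu_count` would need the name patched in its own module instead.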
Summary

Adds a new `abmil_benchmark` CLI to the Mussel package that trains a Gated-ABMIL classifier on per-slide H5 feature files and reports AUROC across multiple seeds with bootstrap 95% CIs.

This supports benchmarking how lower-precision feature storage (float32 / float16 / bfloat16) affects downstream slide-level ABMIL classification performance, complementing the existing tile-level `linear_probe_benchmark`.

Design

- Follows `linear_probe_benchmark` conventions (Hydra config, ConfigStore, same JSON output format)
- Reads per-slide H5 files (`h5["features"]` shape: `n_tiles × feature_dim`)
- `cfg.dtype` (float32 / float16 / bfloat16) cast at load time to simulate information loss from reduced-precision storage; model always trains in float32 for a clean benchmark signal
- `AbmilClassifier` wraps `mussel.models.abmil.ABMIL` (`n_branches=1`) + `nn.Linear` head
- `_numpy_to_torch` (`|V2` opaque void detection)

Output format (`results.json`)

    {
      "dtype": "float32",
      "target_col": "label",
      "n_slides": 200,
      "pos_rate": 0.48,
      "n_seeds": 3,
      "seeds": [42, 43, 44],
      "val": { "auroc": { "mean": 0.82, "std": 0.03 } },
      "test": { "auroc": { "mean": 0.79, "std": 0.02, "bootstrap_ci_95": [0.71, 0.86] } }
    }

Val AUROC `mean` will be `null` when the val split is single-class (too few slides). The Nextflow summary process handles this gracefully.

CLI usage

Files changed

- `mussel/cli/abmil_benchmark.py`
- `pyproject.toml`: `abmil_benchmark` entry point
- `tests/mussel/cli/test_abmil_benchmark.py`

Tests

All 9 tests pass:

- `test_load_h5_features_float32`
- `test_load_h5_features_float16_cast`
- `test_slide_dataset_getitem`: `SlideDataset.__getitem__`
- `test_collate_fn_padding`: `N_max`, mask shape
- `test_collate_fn_mask_correctness`
- `test_abmil_classifier_forward`
- `test_abmil_classifier_forward_with_padding`
- `test_split_no_leakage`
- `test_bootstrap_ci`
- `test_end_to_end`

Bug fixes (discovered during testing)

- `best_state` initialization: it was initialized to `None`, causing a crash when the val AUROC was always undefined (single-class val split). Fixed: initialize to the model's initial weights before the training loop.
- `_sanitize_for_json` inf/nan: `json.dumps(float('-inf'))` produces `-Infinity`, which is not valid JSON. Groovy's `JsonSlurper` (used in Nextflow) raises `JsonInternalException`. Fixed: replace `inf`/`nan` with `None` before serialization.
best_stateinitialization: initialized toNone, causing a crash when the val AUROC was always undefined (single-class val split). Fixed: initialize to the model's initial weights before the training loop._sanitize_for_jsoninf/nan:json.dumps(float('-inf'))produces-Infinity— not valid JSON. Groovy'sJsonSlurper(used in Nextflow) raisesJsonInternalException. Fixed: replaceinf/nanwithNonebefore serialization.