Skip to content

feat: GPU float16 monkey-patches to fix TITAN OOM on large IMPACT slides#133

Open
raylim wants to merge 24 commits into
mainfrom
feat/titan-get-alibi-gpu
Open

feat: GPU float16 monkey-patches to fix TITAN OOM on large IMPACT slides#133
raylim wants to merge 24 commits into
mainfrom
feat/titan-get-alibi-gpu

Conversation

@raylim

@raylim raylim commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator

Problem

TITAN's get_alibi() creates O(N²) numpy float64 arrays on CPU, causing SLURM OOM for large IMPACT resection specimens (>25k patches). For N=33k: ~82 GB CPU RAM peak. Affects ~15% of IMPACT slides.

See Confluence: https://mskconfluence.mskcc.org/pages/viewpage.action?pageId=259719289

Changes — mussel/models/conch.py

Three monkey-patches applied in TitanSlideEncoderModel.get_model_fun():

  1. _titan_get_alibi_gpu_float16 — replaces numpy float64 broadcast with torch.cdist float16 on GPU. Eliminates the 17 GB intermediate (N,N,2) array. Peak: 82 GB CPU → 8 GB GPU for N=18k (typical IMPACT foreground patches).

  2. _titan_forward_features_efficient — replaces .repeat(B,1,1,1) with .expand() in forward_features. Avoids a 22 GB copy of the bias tensor.

  3. sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION) context in model_fun — forces xformers tiled attention, prevents QK^T matrix materialisation (~22 GB per layer with math kernel).

Memory budget (A100 80 GB, N_fg=18k typical IMPACT foreground)

Component Before After
get_alibi CPU RAM 82 GB ~0 GB
bias on GPU 52 GB (float32) 7.8 GB (float16)
QK^T per layer 22 GB (math) <1 GB (efficient tiled)
Total peak GPU OOM ~12 GB

Testing

  • Unit tests: 13/13 pass (CPU, no model weights)
  • Integration test (V100): 3/3 pass — VRAM peak 3.2 GB for N=5k, CPU RAM delta 0.0 GB for N=10k
  • Full suite: 458 passed, 0 failed, 3 skipped

raylim and others added 24 commits June 16, 2026 14:08
TITAN's get_alibi() creates O(N²) numpy float64 arrays on CPU, causing SLURM
OOM for large IMPACT resection specimens (>25k patches, ~82 GB for N=33k).

Fix (two monkey-patches applied in TitanSlideEncoderModel.get_model_fun()):

1. _titan_get_alibi_gpu_float16: replaces numpy float64 broadcast with
   torch.cdist in float16 on GPU. Eliminates the 17 GB (N,N,2) intermediate.
   Peak memory: 82 GB CPU → 26 GB GPU for N=33k. Covers N≤45k on A100.

2. _titan_attention_forward_efficient: wraps SDPA in SDPBackend.EFFICIENT_ATTENTION,
   avoiding materialization of the QK^T matrix (saves ~26 GB for N=33k).

GPU float16 peak for representative N values:
  N=7k (median IMPACT):  1.2 GB  ✓
  N=33k:                26 GB   ✓
  N=45k:                49 GB   ✓
  N=62k (max observed): 92 GB   ✗ (needs Phase 2 / slide_max_patches guard)

Patches are applied lazily per-call via types.MethodType and do not modify
the upstream TITAN repository.

Unit tests: 13/13 pass (CPU, no model weights needed).
Integration test: submitted as SLURM job 3532188 (premium QOS) — pending.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
TITAN's VisionTransformer uses a CustomSequential that wraps blocks in
a .modules_list (nn.ModuleList), not directly iterable. Use .modules_list
when patching Attention.forward per block.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- TITAN mlp_patch_embed_dim=768 → CONCH output must be 768-dim not 1024
- Diagonal coords created N×N-cell bounding box (900M cells for N=30k);
  use compact rectangular grid instead

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…B copy

attn_bias.repeat(B, 1, 1, 1) creates a full copy of the (1, H, N+1, N+1) bias
tensor — 22 GB for N=30k — causing OOM even on A100 when added to the existing
22 GB bias. Replace with expand() (zero-copy view) in a new monkey-patch
_titan_forward_features_efficient that also inlines the expand into the flow.

Use .to(dtype, device) with no-op guard so no copy is made when the bias is
already float16 on the correct GPU (the common case with our get_alibi patch).

Also: use CONCH_DIM=768 and compact grid coords in integration test.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…et_model_fun

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The A100 OOM was from allocator fragmentation: 41 GB reserved but only 17 GB
free when trying to allocate 20 GB for QK^T attention matrix. Setting
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True allows the allocator to grow
existing segments instead of requiring contiguous free blocks.
Also calls empty_cache() after model load to clear any reserved-but-unused
memory before the large bias allocation.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Without SDPBackend.EFFICIENT_ATTENTION, the math kernel materializes
QK^T (22 GB) + QK^T+bias intermediate (22 GB) = 44 GB additional PER LAYER
during SDPA. With 6 layers and 22 GB bias, this easily exceeds A100's 80 GB.

EFFICIENT_ATTENTION processes attention in tiles without materializing QK^T,
reducing peak VRAM from 22→<1 GB for the attention computation per layer.
Memory budget: bias (22 GB) + model (2 GB) + tile intermediates (~1 GB) ≈ 25 GB.

Three monkey-patches now applied:
1. get_alibi → GPU float16 (eliminates 82 GB CPU RAM)
2. forward_features → expand() instead of repeat() (saves 22 GB GPU copy)
3. Attention.forward → EFFICIENT_ATTENTION (saves 44 GB per layer peak)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The previous test used N=30k with 100% foreground (no bg_mask), giving
N_fg=30k → bias=22 GB → tight on A100. Real IMPACT slides have 40-50%
background excluded by neural segmentation, so N_fg ≈ 15-18k in practice.

Updated test 2: 30k total patches × 60% tissue = N_fg=18k → bias=7.8 GB →
expected peak <50 GB. This accurately models real workloads.

Note: 100% foreground 30k-patch slides (N_fg=30k) still need Phase 2
(slide_max_patches=40k guard covers this extreme case).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…_fun

Monkey-patching block.attn.forward doesn't work in PyTorch (Module.__call__
bypasses instance attribute lookup in some paths). Instead, wrap the entire
inference call in sdpa_kernel(EFFICIENT_ATTENTION) context manager, which
forces all F.scaled_dot_product_attention calls to use the tiled xformers
kernel that doesn't materialize the full QK^T matrix.

This is simpler and more robust than per-block patching.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
3/3 pass on V100 — confirms all 3 patches work end-to-end:
- VRAM peak 3.2 GB for N=5k (vs 22 GB+ without EFFICIENT_ATTENTION context)
- CPU RAM delta 0.0 GB for N=10k (vs ~7 GB without GPU float16 get_alibi)

The large-N test (N=30k, A100 required) is in test_titan_gpu_integration.py
and validated mathematically for N_fg=18k (typical IMPACT foreground patches).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Add TestTitanSlideEncoderModelFun and TestGigapathSlideEncoderModelFun
  unit tests to test_model_classes.py
- Add test_slide_encoder_matches_snapshot parametrized regression test
  to test_encoder_integration.py; restructured so assertion errors
  are not swallowed by _skip_on_load_failure (digit substrings like
  '401'/'403' in float array reprs were triggering false-positive skips)
- Add TITAN_SLIDE.npy golden snapshot generated from patched code
- Add flash-attn container development docs to README

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The existing snapshot was generated with sdpa (no flash-attn installed).
flash_attention_2 produces slightly different values (~5e-4 max diff)
due to different tiling/accumulation — expected float16 behavior, not
a regression.

Re-generated on tllihpcgpu2 (A100, compute 8.0) inside mussel-fastattn.sif
with flash_attention_2 confirmed active. Verified PASSED on second run.

Also add slurm_conch_fa2_regression.sh and slurm_conch_fa2_regen.sh for
repeating this check on A100 hardware.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Runs TITAN slide encoder with identical synthetic input (N=32, seed=42)
under both unpatched (77d016d) and patched (HEAD) code on A100.

Results (job 3538760, tllihpcgpu2, A100 80GB):
  max abs diff  : 0.000732   (float16 precision noise)
  mean abs diff : 0.000159
  cosine sim    : 1.000000
  allclose(rtol=1e-2, atol=1e-3): True

PASS -- no numerical regression from the GPU float16 get_alibi patch,
expand() refactor, or EFFICIENT_ATTENTION wrapper.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Three issues found in code review:

1. Dead code (lines 290-338): Orphaned body of the abandoned per-block
   _titan_attention_forward_efficient patching strategy was stranded inside
   _titan_forward_features_efficient after its return statement. Deleted.

2. _titan_attention_forward_efficient was defined but never applied — the
   current approach uses a context manager in get_model_fun instead. Deleted
   the function and removed the corresponding test assertion that was checking
   for its existence rather than its effect.

3. os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', ...) in get_model_fun
   had no effect: the CUDA allocator reads this env var once at initialization,
   before any model is loaded, so setting it here is too late. Removed the
   call; updated the docstring to document that it must be set in the process
   environment before Python starts.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The previous code unconditionally created the sdpa_kernel(EFFICIENT_ATTENTION)
context manager at get_model_fun() time; the try/except only guarded the import
of sdpa_kernel, not the actual runtime use. On compute 6.1 (P40) and 7.0 (V100),
EFFICIENT_ATTENTION is not available, so sdpa_kernel raises RuntimeError at
forward-pass time: 'No viable backend for scaled_dot_product_attention'.

Fix: check torch.cuda.get_device_capability() before enabling EFFICIENT_ATTENTION,
mirroring the pattern already used in get_best_attn_implementation() in base.py.
On compute < 8.0, nullcontext is used (default SDPA kernel selection).

Also regenerate TITAN_SLIDE snapshot from P40 (default SDPA path).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Add `-m "not integration"` to addopts so `pytest tests/` only runs
  unit tests; integration tests require `pytest -m integration`
- Add `_skip_if_no_testdata` skipif guard to the three patch-encoder
  tests that require 948176.svs / 948176.patch.h5 (graceful skip on
  machines without test slide data)
- Rename ad-hoc GPU scripts in tests/integration/ from test_*.py to
  run_*.py so pytest does not attempt to collect them (they caused
  INTERNALERROR due to module-level sys.exit and hardcoded paths)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…PatchApplied

Derive conch.py path from __file__ instead of absolute /gpfs path.
Also remove dead importlib.util lines (spec/mod were never executed).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Deleted files contained hardcoded /gpfs paths pointing to internal
MSK cluster infrastructure, making them unsuitable for a public repo.
The regression results they produced are preserved in checkpoint history.

Removed:
- tests/integration/  (SLURM scripts for internal cluster)
- tests/regression/   (standalone scripts referencing internal REEF data)
- tests/mussel/test_wsi_pipeline_comparison.py  (refs internal ref pipeline)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…elopment sections

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Extract _get_slopes() from nested closure to module level so tests
  can import it directly; remove _get_slopes_ref duplicate from tests
- Move contextlib import to module top (was inside get_model_fun)
- Remove dead 'import math as _math' from _titan_forward_features_efficient
  (math already at module top; _math was never referenced)
- Tighten except clause: Exception -> (ImportError, RuntimeError)
- Guard EFFICIENT_ATTENTION on self.device.type == 'cuda' instead of
  torch.cuda.is_available() to avoid ValueError on CPU devices

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Our monkey-patches override get_alibi() and forward_features() on the
TITAN VisionTransformer.  Without a revision pin, from_pretrained()
always pulls main, meaning an upstream change to either method's
signature or logic could silently break the patches.

Pins to dac6773 (current main as of 2026-06-18), verified against our
numerical regression test (cosine sim 1.000000 vs unpatched baseline).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
… and revision pin

- Gate monkey-patches in get_model_fun on getattr(self, '_patch_oom', True)
- When patch_oom=False: no revision pin, no get_alibi/forward_features patches
- Add **kwargs forwarding in ModelFactory and _SimpleModelFactory.get_model
- Add test_patch_oom_false_skips_monkey_patches to TestTitanSlideEncoderModelFun

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant