feat: tile embeddings pipeline with lazy parquet loading #7
vojtech-cifka wants to merge 49 commits into master from
Conversation
Adds preprocessing/embeddings.py with Virchow2 model wrapper and per-slide TileDataset that reads tiles from WSIs via OpenSlide. Tiles and slide metadata are fetched from the tiling MLflow run; embeddings are saved as per-slide parquet files and logged to MLflow. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
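The per-slide dataset described above can be sketched as follows. This is a hedged illustration, not the PR's actual code: the OpenSlide reader is injected as a callable so the sketch runs without a slide file, and all field names are assumptions based on the description.

```python
# Sketch of a per-slide tile dataset in the spirit of the PR description.
# The real class reads tiles from WSIs via OpenSlide; here the region reader
# is injected (e.g. openslide.OpenSlide(path).read_region) so the idea is
# runnable without a slide file. Field names are illustrative.
class TileDataset:
    def __init__(self, slide_id, tiles, read_region):
        self.slide_id = slide_id
        self.tiles = tiles              # rows from tiles.parquet: x, y, level, size
        self.read_region = read_region  # (location, level, size) -> image

    def __len__(self):
        return len(self.tiles)

    def __getitem__(self, i):
        t = self.tiles[i]
        tile = self.read_region((t["x"], t["y"]), t["level"], (t["size"], t["size"]))
        return {"slide_id": self.slide_id, "tile": tile}
```

With a stub reader, `TileDataset("s1", [{"x": 0, "y": 0, "level": 0, "size": 224}], reader)[0]` yields a dict carrying the slide ID and the read region.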
Add base image, HF_TOKEN export, --frozen sync, and PROJECTS storage mount. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Consistent with how all other preprocessing run IDs are stored. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough
Adds an embeddings preprocessing pipeline: new configs, a distributed embedding implementation using Ray and an async model client, MLflow artifact wiring, a job submission script, and PyTorch-related dependencies.
Changes: Embeddings Preprocessing Pipeline
Sequence Diagram
sequenceDiagram
participant Runner as Runner
participant MLflow as MLflow
participant Storage as Storage
participant Ray as Ray
participant Model as AsyncClient
Runner->>MLflow: download "<split>_split" artifacts
Runner->>Storage: read slides.parquet and tiles.parquet
Runner->>Ray: build dataset from tiles.parquet and join slide metadata
Runner->>Ray: repartition by block_size and map EmbedTiles (ActorPoolStrategy)
Ray->>Model: embed_image(tile) [concurrent actors]
Model-->>Ray: return flattened embedding
Ray->>Storage: write slides.parquet and tiles/<embeddings>
Runner->>MLflow: log split directory as artifact
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces a tile embedding preprocessing pipeline, featuring Hydra configurations for model selection and a Ray-based script for distributed tile processing. It also includes a job submission script and updates the project's deep learning dependencies. Review feedback identifies critical placeholder values in the submission script that must be replaced to avoid runtime errors and suggests relocating Ray initialization into the main function to improve modularity and testability.
Actionable comments posted: 1
🧹 Nitpick comments (2)
preprocessing/embeddings.py (2)
29-37: 🏗️ Heavy lift: Consider adding error handling for API failures.
The `__call__` method has no error handling. Given the scale (~80M tiles per PR notes) and the 200s timeout, transient network failures or service errors are likely. A single unhandled exception will cause the Ray task to fail and potentially retry the entire block. Consider adding retry logic with exponential backoff for transient errors:
♻️ Suggested retry wrapper
```diff
+import asyncio
+from httpx import HTTPStatusError, TimeoutException
+
 class EmbedTiles:
     def __init__(self, model: str, concurrency: int) -> None:
         self.model = model
         self.client = AsyncClient(
             limits=httpx.Limits(
                 max_connections=concurrency, max_keepalive_connections=concurrency
             ),
             timeout=200,
         )
+        self.max_retries = 3

     async def __call__(self, row: dict[str, Any]) -> dict[str, Any]:
-        embedding = (
-            (await self.client.models.embed_image(self.model, row["tile"]))
-            .reshape(-1)
-            .tolist()
-        )
+        for attempt in range(self.max_retries):
+            try:
+                embedding = (
+                    (await self.client.models.embed_image(self.model, row["tile"]))
+                    .reshape(-1)
+                    .tolist()
+                )
+                break
+            except (TimeoutException, HTTPStatusError) as e:
+                if attempt == self.max_retries - 1:
+                    raise
+                await asyncio.sleep(2 ** attempt)
         del row["tile"]
         row["embedding"] = embedding
         return row
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@preprocessing/embeddings.py` around lines 29 - 37, The __call__ method needs retry/error handling around the call to self.client.models.embed_image(self.model, row["tile"]) to avoid failing the Ray task on transient API/network errors; wrap that call in a retry loop with exponential backoff (e.g., max_attempts, base_delay doubling each retry), catch transient exceptions (network errors, timeouts, 5xx responses), log each retry attempt, and re-raise only after max attempts; also only delete row["tile"] and assign row["embedding"] after a successful embed to avoid losing data if all retries fail.
64-67: 💤 Low value: Potential `KeyError` if tile references unknown slide.
The lambda assumes every `row["slide_id"]` exists in `slide_info`. If a tile record references a slide not present in slides.parquet, this will raise a `KeyError` and fail the pipeline. If data integrity is guaranteed upstream (tiling pipeline), this is fine. Otherwise, consider defensive handling:

lambda row, si: {**row, **si.get(row["slide_id"], {})}

Or validate at the start of processing that all slide IDs in tiles exist in slides.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@preprocessing/embeddings.py` around lines 64 - 67, The map lambda that merges tile rows with slide info will raise a KeyError if row["slide_id"] is missing from slide_info; update the lambda used in the map call (the inline lambda that references slide_info/si and row["slide_id"]) to defensively handle missing keys (e.g., look up slide_info with a safe-get and merge an empty dict when absent) or add a validation step before mapping that ensures all tile slide IDs exist in slide_info and fail early with a clear error; modify the lambda in the map invocation or add a precheck function to perform this validation.
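The defensive merge suggested in that nitpick can be exercised in isolation. Below, `slide_info` is a plain dict standing in for the loaded slides.parquet metadata, and `merge_slide_info` is a hypothetical named version of the review's lambda:

```python
# Safe merge: a tile whose slide_id is missing from slide_info passes through
# unchanged instead of raising KeyError (mirrors the review suggestion).
def merge_slide_info(row, slide_info):
    return {**row, **slide_info.get(row["slide_id"], {})}

slide_info = {"s1": {"mpp": 0.5, "path": "/data/s1.svs"}}

known = merge_slide_info({"slide_id": "s1", "x": 0}, slide_info)
unknown = merge_slide_info({"slide_id": "s9", "x": 0}, slide_info)
```

The known tile gains the slide metadata; the unknown one is returned as-is, which downstream stages can then detect and drop or report.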
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@scripts/submit_embeddings.py`:
- Around line 4-18: The submit_job call uses placeholder values that break
execution: replace the Ellipsis passed to username in submit_job(...) with a
real username string or a clear placeholder like "<YOUR_USERNAME>", and update
the script list element that currently contains "+experiment=..." to either a
valid Hydra experiment path (e.g.,
"+experiment=preprocessing/embeddings_virchow2_05mpp") or a clear placeholder
like "+experiment=<EXPERIMENT_NAME>"; ensure these changes are applied to the
submit_job invocation and/or add a brief inline comment next to username and the
+experiment entry to indicate they must be filled before running.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ec9a215b-78b2-4847-bb29-2e83db6712d0
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (7)
configs/data/dataset.yaml
configs/experiment/preprocessing/embeddings_05mpp.yaml
configs/experiment/preprocessing/embeddings_virchow2_05mpp.yaml
configs/preprocessing/embeddings.yaml
preprocessing/embeddings.py
pyproject.toml
scripts/submit_embeddings.py
Actionable comments posted: 1
🧹 Nitpick comments (1)
preprocessing/embeddings.py (1)
86-91: 💤 Low value: Consider making actor pool parameters configurable.
`max_size=4` and `max_tasks_in_flight_per_actor=8` are hardcoded while `config.concurrency` (512) controls the HTTP connection pool. This creates a mismatch that could be confusing for tuning. The effective parallelism is: 4 actors × 8 tasks × async concurrency. Consider exposing `max_size` and `max_tasks_in_flight_per_actor` in the config for easier tuning without code changes.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@preprocessing/embeddings.py` around lines 86 - 91, The ActorPoolStrategy hardcodes max_size=4 and max_tasks_in_flight_per_actor=8 which mismatches config.concurrency; update the config object to add two new fields (e.g., actor_pool_max_size with default 4 and actor_pool_max_tasks_in_flight_per_actor with default 8) and replace the literals in the call to ray.data.ActorPoolStrategy inside embeddings.py (the compute=ray.data.ActorPoolStrategy(...) invocation) to use config.actor_pool_max_size and config.actor_pool_max_tasks_in_flight_per_actor so pool sizing is configurable and consistent with config.concurrency.
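One way to act on that nitpick is a config fragment along these lines. This is a sketch: the field names are assumptions, chosen to sit alongside the existing `concurrency` key in a Hydra-style `configs/preprocessing/embeddings.yaml`.

```yaml
# configs/preprocessing/embeddings.yaml (sketch; field names are assumptions)
concurrency: 512                        # HTTP connection pool per actor
actor_pool_max_size: 4                  # replaces hardcoded max_size=4
actor_pool_max_tasks_in_flight: 8       # replaces hardcoded max_tasks_in_flight_per_actor=8
```

The `ray.data.ActorPoolStrategy(...)` call would then read these two values from the config instead of literals, keeping all three knobs that govern effective parallelism in one place.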
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@preprocessing/embeddings.py`:
- Around line 19-37: EmbedTiles currently leaks httpx connections and lacks
error handling: add explicit async cleanup and handle embed_image failures.
Implement an async close/shutdown method on the EmbedTiles class that calls
self.client.aclose() (and call it from actor teardown), and add a fallback
__del__ that schedules closing to avoid lingering pools if shutdown isn’t
invoked; reference the AsyncClient instance self.client and the class
EmbedTiles. Wrap the embed_image call inside async __call__ in a try/except (or
a small retry loop) to catch network/service errors, log the exception and
either retry a few times or return a deterministic error marker in the row
(e.g., set row["error"] or row["embedding"]=None) instead of letting the
exception crash the actor; reference the __call__ method and the embed_image
invocation. Ensure you remove row["tile"] only after successful embedding to
avoid losing data on failure.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c5d9d565-abb6-4546-8a58-a94ad917ef50
📒 Files selected for processing (1)
preprocessing/embeddings.py
Force-pushed from 0bf7ff5 to 606b8e4 (Compare)
Resolved conflicts in configs/data/dataset.yaml and pyproject.toml by keeping both sets of additions: tissue_masks_run_id from master and torch/torchvision/timm/einops deps from this branch. Regenerated uv.lock.
Wraps the embed_image call in a retry loop (up to 3 attempts with exponential backoff) so transient network errors don't cause Ray to retry the entire block. Re-raises after all attempts are exhausted so failures stay visible. Moves del row["tile"] into a finally block to free tile pixel data promptly even when an exception occurs.
Skip tiles with no annotation coverage and no tissue coverage before feeding them into the Ray pipeline, using PyArrow predicate pushdown to avoid materialising the full 80M-row dataset in memory. Tissue stats run ID stored in dataset.yaml; referenced from the embeddings config via tile_filters. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
pc.or_ is an array kernel; expression combination uses the | operator. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…tering Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… splits Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Ray actor logs and progress bars provide cold-start visibility; the bespoke first_row_logged flag was a tuning aid no longer needed. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replaces hand-rolled retry loop in EmbedTiles with a tenacity-decorated helper. Retries are now scoped to httpx.HTTPError (network/timeout/status) so programming bugs surface immediately instead of being retried 3x. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Ray Data progress bars surface rows/sec and in-flight counts, so the periodic counter log and the latency/in_flight bookkeeping it required are redundant. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Reading three projected columns doesn't need 8 GB; default scheduling lets Ray pack more readers per node. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The `@retry` decorator at class scope captures an AsyncRetrying instance whose internal threading.local() makes EmbedTiles unpicklable, which breaks Ray Data's actor serialization. Constructing the retryer in __init__ moves that state onto the worker. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
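The serialization failure described here can be reproduced without tenacity or Ray: an already-constructed object carrying a `threading.local` cannot be pickled, while shipping the class plus constructor arguments works because `__init__` runs after deserialization. `RetryingActor` is an illustrative stand-in for `EmbedTiles`:

```python
import pickle
import threading

class RetryingActor:
    """Sketch: threading.local state must be built in __init__ (on the
    worker), not captured in an object that is serialized and shipped."""
    def __init__(self):
        self.retryer = threading.local()  # stands in for tenacity's AsyncRetrying

# Shipping a constructed instance fails: _thread._local cannot be pickled.
try:
    pickle.dumps(RetryingActor())
    unpicklable = False
except TypeError:
    unpicklable = True

# Shipping the class plus constructor args works; __init__ runs after
# deserialization, so the threading.local is created on the receiving side.
cls, args = pickle.loads(pickle.dumps((RetryingActor, ())))
obj = cls(*args)
```

This is why moving the retryer construction into `__init__` restores Ray Data's ability to serialize the actor class.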
Without the 8 GB memory reservation on read_parquet, downstream stages stall and no embeddings are produced. Restoring until we understand the scheduling interaction. This reverts commit 287408f. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ors" Tenacity-based retries broke embedding production under high actor concurrency (likely shared AsyncRetrying state and narrowed exception filter dropping retryable non-httpx errors). Restoring the manual loop, which was last known to work. This reverts commit 82163ac. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
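The restored manual loop has roughly this shape. A runnable sketch, not the PR's code: `flaky_embed` is a fake client call that fails twice before succeeding, and the backoff is shortened so the demo runs quickly.

```python
import asyncio

async def embed_with_retries(embed, tile, max_retries=3):
    # Manual retry loop: exponential backoff between attempts, re-raise
    # after the final attempt so failures stay visible.
    for attempt in range(max_retries):
        try:
            return await embed(tile)
        except Exception:
            if attempt == max_retries - 1:
                raise
            await asyncio.sleep(2 ** attempt * 0.01)  # shortened for the demo

calls = {"n": 0}

async def flaky_embed(tile):
    # Fake model call: raises twice, then returns an embedding.
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient")
    return [0.1, 0.2]

result = asyncio.run(embed_with_retries(flaky_embed, "tile-bytes"))
```

Because it catches broadly and retries in-process, this loop tolerates the non-httpx transient errors that the narrower tenacity filter was suspected of dropping.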
Restore the last-known-working version of EmbedTiles after several review-driven refactors broke embedding production. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Three-column projection doesn't justify 8 GB per task; let Ray's default scheduling pack readers per node. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Build AsyncRetrying in __init__ so the actor stays picklable (threading.local inside tenacity can't cross the wire if captured at class scope). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ors" This reverts commit ea437e5.
Summary
- (annotation- and tissue-filtered tiles) rather than raw tiling output
- read_slide_tiles, and dispatch async embed calls through an actor pool
- output)