MPS CI Support #1278
Conversation
* Fix type of `HookedTransformerConfig.device`. This is typed as `Optional[str]` but sometimes returns `torch.device`. Updated the code to just return the `str` instead of wrapping with a device. I'm not confident that every function which takes a device will always be passed a string, so I didn't change functions like `warn_if_mps`. Found while working on TransformerLensOrg#1219
* more cleanup
* 3.0 CI Bugs (TransformerLensOrg#1261)
  * Fixing `utils` imports
  * skip gated notebooks on PR from forks
  * Updating notebooks
  * Ensure LLaMA only runs when HF_TOKEN is available

Co-authored-by: jlarson4 <jonahalarson@comcast.net>
TransformerLens 3.1.0
Pull request overview
Adds Apple Silicon MPS coverage to CI by introducing an MPS-specific test suite and a macOS GitHub Actions job, alongside device-selection tweaks to make MPS opt-in by default.
Changes:
- Added a new `tests/mps` smoke-test suite that validates basic tensor ops and a small `HookedTransformer` run on MPS.
- Added an `mps-checks` GitHub Actions job on `macos-latest` to run unit/integration tests plus the new MPS smoke tests on PRs to `main` and pushes to `main`.
- Updated device utilities and configs to better support MPS opt-in behavior, plus proactive `torch.mps.empty_cache()` cleanup in pytest fixtures.
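The fixture cleanup described above could look roughly like the following sketch. This assumes pytest; the helper name `maybe_clear_mps_cache` is illustrative, not the PR's actual code, and the guard keeps it a harmless no-op on Linux/CUDA runners or when torch is absent.

```python
import pytest


def maybe_clear_mps_cache() -> bool:
    """Free cached MPS memory if a torch build with MPS support is present.

    Returns True when the cache was actually cleared, False otherwise,
    so this stays a no-op on runners without Metal.
    """
    try:
        import torch
    except ImportError:
        return False
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        torch.mps.empty_cache()
        return True
    return False


@pytest.fixture(autouse=True)
def clear_mps_cache():
    # Run the test first, then release MPS memory to reduce CI OOM risk.
    yield
    maybe_clear_mps_cache()
```

An `autouse=True` fixture in `tests/conftest.py` applies to every test without each test opting in, which is why it is a natural home for cache hygiene.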
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.
Summary per file:
| File | Description |
|---|---|
| transformer_lens/utilities/devices.py | Adjusts device selection behavior/signatures to support MPS opt-in and updated warning typing. |
| transformer_lens/train.py | Updates training config typing and default device assignment. |
| transformer_lens/config/HookedTransformerConfig.py | Uses get_device() directly when defaulting cfg.device. |
| tests/unit/utilities/test_devices.py | Updates device utility unit tests for the new get_device() return type. |
| tests/mps/test_mps_basic.py | Adds MPS-only smoke tests covering device detection, core ops, and small-model forward/cache paths. |
| tests/mps/__init__.py | Declares the MPS test package. |
| tests/conftest.py | Adds MPS cache clearing after tests/classes/session to reduce CI OOM risk. |
| pyproject.toml | Registers a no_mps pytest marker. |
| .github/workflows/checks.yml | Adds the mps-checks CI job that runs on macOS and executes the MPS tests. |
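As a rough illustration of the smoke-test shape (not the actual contents of `tests/mps/test_mps_basic.py`; the test name and body are made up), each test needs a guard so the suite still collects cleanly off Apple Silicon:

```python
import importlib.util

import pytest

# Detect torch without importing it, so collection works everywhere.
HAS_TORCH = importlib.util.find_spec("torch") is not None


@pytest.mark.skipif(not HAS_TORCH, reason="torch not installed")
def test_matmul_runs_on_mps():
    import torch

    if not torch.backends.mps.is_available():
        pytest.skip("MPS not available on this machine")
    # A basic tensor op executed on the Metal device.
    a = torch.randn(4, 4, device="mps")
    b = torch.randn(4, 4, device="mps")
    assert (a @ b).device.type == "mps"
```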
```diff
 def get_device() -> str:
     """Get the best available device, with MPS safety checks.

     MPS is only auto-selected when the environment variable
     ``TRANSFORMERLENS_ALLOW_MPS=1`` is set **and** the installed PyTorch
     version meets or exceeds ``_MPS_MIN_SAFE_TORCH_VERSION``.

     Returns:
-        torch.device: The best available device (cuda, mps, or cpu)
+        str: The best available device name (cuda, mps, or cpu)
     """
     if torch.cuda.is_available():
-        return torch.device("cuda")
+        return "cuda"
```
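Independent of the typing question, the opt-in behavior the docstring describes amounts to logic like the following sketch. It is dependency-injected for illustration: `pick_device` and the version threshold value are hypothetical, not the library's API.

```python
# Hypothetical threshold standing in for _MPS_MIN_SAFE_TORCH_VERSION.
_MPS_MIN_SAFE_TORCH_VERSION = (2, 1)


def pick_device(cuda_available: bool, mps_available: bool,
                torch_version: tuple, env: dict) -> str:
    """Return a device name string, preferring cuda > opt-in mps > cpu."""
    if cuda_available:
        return "cuda"
    if (
        mps_available
        and env.get("TRANSFORMERLENS_ALLOW_MPS") == "1"
        and torch_version >= _MPS_MIN_SAFE_TORCH_VERSION
    ):
        return "mps"
    return "cpu"
```

The key property is that MPS is never auto-selected unless both the environment variable and the torch version check pass; everything else falls back to CPU.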
Roll back any changes related to `get_device`; there were some git history issues that caused this to crop up here for you.
I discussed this with another contributor and set up #1230 to standardize the type returned by `get_device`. There was a typing inconsistency: `get_device() -> torch.device` only returned a `torch.device` object in the default case, and returned a string in all other instances, so the annotation was silently wrong. We standardized everything to `str`, but for some reason that change never made it to dev. I have integrated it into this branch, and it should no longer be a problem.
Hi @jlarson4, I've updated the PR to address the automated feedback:
```yaml
            print(f'MPS built: {torch.backends.mps.is_built()}')
            assert torch.backends.mps.is_available(), 'MPS not available on this runner!'
            "
      - name: MPS Unit Tests
```
Make sure you're passing `${{ secrets.HF_TOKEN }}` to any tests that might load models so we can avoid HuggingFace rate limits where possible.
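A hypothetical workflow fragment (step name and layout assumed, not the PR's actual YAML) would expose the secret as an environment variable so `huggingface_hub` can authenticate instead of hitting anonymous rate limits:

```yaml
      - name: MPS Unit Tests
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: >
          uv run pytest tests/unit -v
```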
```yaml
      - name: MPS Unit Tests
        run: >
          uv run pytest tests/unit -v
          --ignore=tests/unit/model_bridge/
```
We are ignoring a lot of tests here; some should be able to run within our memory limit. Try ignoring only these tests, and let me know if you found something while setting this up that necessitates skipping more.
```yaml
      - name: MPS Unit Tests
        run: >
          uv run pytest tests/unit -v
          --ignore=tests/unit/model_bridge/test_optimizer_compatibility.py
          --ignore=tests/unit/model_bridge/test_gpt_oss_moe.py
          --ignore=tests/unit/model_bridge/test_component_inspection.py
          --ignore=tests/unit/model_bridge/test_key_analysis.py
          --ignore=tests/unit/model_bridge/test_benchmark_gated_hooks_fire.py
          --ignore=tests/unit/model_bridge/test_weight_processing_adapter_paths.py
          --ignore=tests/unit/model_bridge/test_bridge_generate_kv_cache.py
          --ignore=tests/unit/model_bridge/test_bridge_vs_hooked_transformer_patching.py
          --ignore=tests/unit/model_bridge/compatibility/
```
```yaml
          --ignore=tests/unit/model_bridge/
      - name: MPS Integration Tests
        run: >
          uv run pytest tests/integration -v
```
We should still be able to run `grouped_query_attention` and `create_hooked_encoder`, which are ignored here.
```diff
-markers=["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
+markers=[
+    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
+    "no_mps: marks test as incompatible with MPS device (deselect with '-m \"not no_mps\"')",
```
If we aren't using this `no_mps` marker, we should drop it.
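For reference, a test would opt out of MPS runs with the marker like this (the test name and body are made-up placeholders), and the MPS job would then deselect it with `pytest -m "not no_mps"`:

```python
import pytest


@pytest.mark.no_mps  # marker registered in pyproject.toml's markers list
def test_float64_op():
    # Placeholder body: float64 is a classic MPS incompatibility.
    assert 1.0 + 1.0 == 2.0
```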
Hi @jlarson4,
This PR implements MPS (Metal Performance Shaders) CI Runner Support as proposed in #1264.
The goal is to provide automated testing for the Apple Silicon research community while working within the limits of GitHub's Mac runners.
Key Changes:
- Added `tests/mps/test_mps_basic.py` with 11 smoke tests covering device detection, core tensor ops on Metal, and `HookedTransformer` forward passes/caching with small models (TinyStories-1M).
- Added an `mps-checks` job in `.github/workflows/checks.yml`. It uses `macos-latest` and runs only on PRs/pushes to main.
- Updated `tests/conftest.py` to proactively clear the MPS cache after every test using `torch.mps.empty_cache()`.
- Skipped memory-heavy test suites (`model_bridge`) to ensure stability.
- Gated MPS auto-selection behind `TRANSFORMERLENS_ALLOW_MPS=1` to ensure safe defaults for Mac users.

Type of change
Checklist: