Skip to content

MPS CI Support#1278

Open
huseyincavusbi wants to merge 17 commits intoTransformerLensOrg:devfrom
huseyincavusbi:feat/mps-ci-support
Open

MPS CI Support#1278
huseyincavusbi wants to merge 17 commits intoTransformerLensOrg:devfrom
huseyincavusbi:feat/mps-ci-support

Conversation

@huseyincavusbi
Copy link
Copy Markdown
Contributor

@huseyincavusbi huseyincavusbi commented May 2, 2026

Hi @jlarson4,

This PR implements MPS (Metal Performance Shaders) CI Runner Support as proposed in #1264.

The goal is to provide automated testing for the Apple Silicon research community while working within the limits of GitHub's Mac runners.

Key Changes:

  • New Test Suite: Added tests/mps/test_mps_basic.py with 11 smoke tests covering device detection, core tensor ops on Metal, and HookedTransformer forward passes/caching with small models (TinyStories-1M).
  • CI Automation: Introduced the mps-checks job in .github/workflows/checks.yml. It uses macos-latest and runs only on PRs/pushes to main.
  • Memory Management:
    • Updated tests/conftest.py to proactively clear the MPS cache after every test using torch.mps.empty_cache().
    • Configured the CI to ignore memory-intensive modules (e.g., model_bridge) to ensure stability.
  • Opt-in Mechanism: Respects TRANSFORMERLENS_ALLOW_MPS=1 to ensure safe defaults for Mac users.

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

brendanlong and others added 10 commits April 20, 2026 14:50
* Fix type of HookedTransformerConfig.device

This is typed as `Optional[str]` but sometimes returns `torch.device`.
Updated the code to just return the `str` instead of wrapping with a
device.

I'm not confident that every function which takes a device will
always be passed a string, so I didn't change functions like
warn_if_mps.

Found while working on TransformerLensOrg#1219

* more cleanup

* 3.0 CI Bugs (TransformerLensOrg#1261)

* Fixing `utils` imports

* skip gated notebooks on PR from forks

* Updating notebooks

* Ensure LLaMA only runs when HF_TOKEN is available

---------

Co-authored-by: jlarson4 <jonahalarson@comcast.net>
Copilot AI review requested due to automatic review settings May 2, 2026 16:03
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds Apple Silicon MPS coverage to CI by introducing an MPS-specific test suite and a macOS GitHub Actions job, alongside device-selection tweaks to make MPS opt-in by default.

Changes:

  • Added a new tests/mps smoke-test suite that validates basic tensor ops and a small HookedTransformer run on MPS.
  • Added an mps-checks GitHub Actions job on macos-latest to run unit/integration tests plus the new MPS smoke tests on PRs to main and pushes to main.
  • Updated device utilities and configs to better support MPS opt-in behavior, plus proactive torch.mps.empty_cache() cleanup in pytest fixtures.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
transformer_lens/utilities/devices.py Adjusts device selection behavior/signatures to support MPS opt-in and updated warning typing.
transformer_lens/train.py Updates training config typing and default device assignment.
transformer_lens/config/HookedTransformerConfig.py Uses get_device() directly when defaulting cfg.device.
tests/unit/utilities/test_devices.py Updates device utility unit tests for the new get_device() return type.
tests/mps/test_mps_basic.py Adds MPS-only smoke tests covering device detection, core ops, and small-model forward/cache paths.
tests/mps/init.py Declares the MPS test package.
tests/conftest.py Adds MPS cache clearing after tests/classes/session to reduce CI OOM risk.
pyproject.toml Registers a no_mps pytest marker.
.github/workflows/checks.yml Adds the mps-checks CI job that runs on macOS and executes the MPS tests.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +56 to +67
def get_device() -> str:
"""Get the best available device, with MPS safety checks.

MPS is only auto-selected when the environment variable
``TRANSFORMERLENS_ALLOW_MPS=1`` is set **and** the installed PyTorch
version meets or exceeds ``_MPS_MIN_SAFE_TORCH_VERSION``.

Returns:
torch.device: The best available device (cuda, mps, or cpu)
str: The best available device name (cuda, mps, or cpu)
"""
if torch.cuda.is_available():
return torch.device("cuda")
return "cuda"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Roll back any changes related to get_device, there were some git history issues that caused this to crop up here for you.

I discussed this with another contributor and setup #1230 to standardize the type returned by get_device. There was a typing inconsistency, get_device() -> torch.device only returned a torch.device object in the default case, and returned a string in all other instances, creating a situation where the typing was silently wrong. We standardized everything to str, but for some reason that change never made it to dev, I have integrated it into this branch and it should no longer be a problem.

Comment thread transformer_lens/utilities/devices.py
Comment thread transformer_lens/train.py
Comment thread tests/mps/test_mps_basic.py Outdated
@huseyincavusbi
Copy link
Copy Markdown
Contributor Author

Hi @jlarson4, I've updated the PR to address the automated feedback:

  • API Stability: Reverted get_device() to return torch.device objects.
  • Type Checks: Updated type hints across model classes to resolve mypy failures.
  • CI Trigger: Strictly restricted mps-checks to the main branch

@jlarson4 jlarson4 mentioned this pull request May 4, 2026
7 tasks
print(f'MPS built: {torch.backends.mps.is_built()}')
assert torch.backends.mps.is_available(), 'MPS not available on this runner!'
"
- name: MPS Unit Tests
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure you're passing $${{ secrets.HF_TOKEN }} to any tests that might load models so we can avoid HuggingFace rate limits where possible

- name: MPS Unit Tests
run: >
uv run pytest tests/unit -v
--ignore=tests/unit/model_bridge/
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are ignoring a lot of tests here, some should be able to run within our memory limit. Try ignoring only these tests and let me know if you found something while setting this up that necessitates skipping more tests.

- name: MPS Unit Tests
  run: >
    uv run pytest tests/unit -v
    --ignore=tests/unit/model_bridge/test_optimizer_compatibility.py
    --ignore=tests/unit/model_bridge/test_gpt_oss_moe.py
    --ignore=tests/unit/model_bridge/test_component_inspection.py
    --ignore=tests/unit/model_bridge/test_key_analysis.py
    --ignore=tests/unit/model_bridge/test_benchmark_gated_hooks_fire.py
    --ignore=tests/unit/model_bridge/test_weight_processing_adapter_paths.py
    --ignore=tests/unit/model_bridge/test_bridge_generate_kv_cache.py
    --ignore=tests/unit/model_bridge/test_bridge_vs_hooked_transformer_patching.py
    --ignore=tests/unit/model_bridge/compatibility/

--ignore=tests/unit/model_bridge/
- name: MPS Integration Tests
run: >
uv run pytest tests/integration -v
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should still be able to run grouped_query_attention and create_hooked_encoder that are ignored here.

Comment thread pyproject.toml
markers=["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
markers=[
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"no_mps: marks test as incompatible with MPS device (deselect with '-m \"not no_mps\"')",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we aren't using this no_mps tag, we should drop it

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants