Skip to content

Add MLX backend support for Nutpie compilation#254

Open
cetagostini wants to merge 16 commits into
pymc-devs:mainfrom
cetagostini:cetagostini/adding_mlx_backend
Open

Add MLX backend support for Nutpie compilation#254
cetagostini wants to merge 16 commits into
pymc-devs:mainfrom
cetagostini:cetagostini/adding_mlx_backend

Conversation

@cetagostini

Copy link
Copy Markdown

Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or Pytensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX.

Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or Pytensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX.
@aseyboldt

Copy link
Copy Markdown
Member

Thanks, that looks great!
I think we probably should call mlx.compile on the final functions though?

Bump MLX version requirement to >=0.29.0 in pyproject.toml for all relevant extras. In compile_pymc.py, JIT compile the logp function using mx.compile for improved performance, aligning with JAX backend behavior.
@cetagostini

cetagostini commented Oct 27, 2025

Copy link
Copy Markdown
Author

Thanks, that looks great! I think we probably should call mlx.compile on the final functions though?

Good point, that simple addition brings between 5% to 20% more performance! @aseyboldt

@cetagostini

Copy link
Copy Markdown
Author

@aseyboldt solve the test issue to work only on macs with intel chips.

@cetagostini

Copy link
Copy Markdown
Author

@aseyboldt can you give me a hand? The test failing its strange. My local pass everythig.

@cetagostini cetagostini requested review from aseyboldt and jessegrabowski and removed request for aseyboldt October 30, 2025 12:46

@aseyboldt aseyboldt left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That failure is annoying. For some reason the results seem to differ between different machines? I think we really should figure out what's going on here. Maybe it helps if we print the first couple of values in warmup_posterior to see if the initial values are already different, or if small differences accumulate?

Comment thread python/nutpie/compiled_pyfunc.py Outdated
updated.update(**updates)

# Convert to MLX arrays if using MLX backend (indicated by force_single_core)
if self._force_single_core:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not use that argument to detect mlx.
How about we add an attribute _convert_data_item or so to the dataclass, that contains a function that transforms data arrays? We could then also use that for jax.

cetagostini and others added 8 commits April 28, 2026 13:36
Resolve conflicts to keep the MLX backend goals while adopting upstream
fixes and CI restructure:

- compiled_pyfunc.py: combine upstream's new extra_callback / extra_callback_rate
  parameters in PyFuncModel._make_sampler with the MLX force_single_core
  guard.
- .github/workflows/ci.yml: adopt upstream's split build/test jobs with
  the suite matrix (stan/pymc/flow), test_pymc_dev, docs, deploy-docs.
  Install mlx only in the macOS pymc suite on aarch64 (x86_64 macOS
  pymc is excluded per upstream's matrix).
- compile_pymc.py auto-merge: keep _compile_pymc_model_mlx, MLX backend
  in compile_pymc_model and _make_functions, and adopt upstream's
  PyTensor compat refactor (pt.grad, allow_xtensor_conversion, pytensor
  imports).
- tests/test_pymc.py auto-merge: keep dynamic backend_params with MLX
  guarded by find_spec, plus upstream's test_progress_callback.

Made-with: Cursor
Address review comments and CI failures on PR pymc-devs#254:

* MLX is not thread-safe (Metal command-buffer race,
  ml-explore/mlx#2133), so always set ``force_single_core=True`` for
  ``backend="mlx"`` regardless of ``gradient_backend``. This unblocks the
  default config (``gradient_backend="pytensor"``) with ``chains>=2``,
  which previously aborted with
  "A command encoder is already encoding to this command buffer".

* Decouple shared-data type conversion from ``_force_single_core``:
  ``PyFuncModel`` now carries an optional ``_shared_data_converter``
  callable (set to ``mx.array`` for MLX) used by ``with_data``. The old
  code was abusing ``_force_single_core`` as a "is MLX backend" proxy.

* Drop the stale ``raw_logp_fn`` plumbing in the MLX backend. The
  transform adapter is flowjax-based (JAX-only), so MLX could never
  expose a usable raw logp.

* Pin ``mlx<0.31`` in ``pyproject.toml`` extras and the CI install.
  mlx 0.31.x crashes inside ``Compiled::eval_gpu`` when the sampler
  worker thread evaluates auto-fused element-wise kernels
  (ml-explore/mlx#3329), which is the root cause of the macOS aarch64
  pymc test segfaults. mlx 0.29.x and 0.30.x are unaffected.

* Add a runtime guard in ``_compile_pymc_model_mlx`` that raises a
  helpful ``RuntimeError`` if mlx>=0.31 is installed anyway, instead of
  segfaulting partway through sampling.

* Make ``test_deterministic_sampling_mlx`` a smoke test (drop
  ``array_compare`` + reference file). MLX sampling is not bit-identical
  across machines/MLX versions, mirroring the situation already
  documented for ``test_normalizing_flow``.

* Mark ``test_dims_model[mlx-*]`` as ``xfail`` while
  pymc-devs/pytensor#1350 is open (PyTensor's MLX linker has no
  ``XTensorFromTensor`` op yet).

* Mirror logp's ``mx.array(_x)`` conversion in the expand closure for
  consistency.

* Ignore local AI tooling folders (``.ai/``, ``.cursor/``, ``.claude/``).

Made-with: Cursor
Keep only the upstream-issue references (mlx#2133, mlx#3329) since
those are non-obvious context; drop the explanatory comments whose
content is already conveyed by the surrounding code.

Made-with: Cursor
…mlx_backend

Co-authored-by: Cursor <cursoragent@cursor.com>

# Conflicts:
#	python/nutpie/compiled_pyfunc.py
Deep-dive into the cross-machine non-determinism flagged in review: the
MLX backend evaluates the logp/gradient on the Metal GPU, where float64
is unsupported, so ``mx.array(_x)`` silently downcasts nutpie's float64
positions to float32. Measured relative error is ~1e-7 (float32 eps), and
because NUTS is a chaotic integrator that error accumulates across the
leapfrog steps and varies with GPU/Metal/kernel fusion -- which is why the
array_compare reference test could never match between machines.

* Cast inputs explicitly to ``mx.float32`` in the logp and expand closures
  so the precision contract is intentional and cannot silently change with
  the MLX version or a global default-dtype override, and document the
  root cause at the conversion site.

* Replace the bare smoke test with a machine-independent statistical
  correctness check: sample HalfNormal and assert the posterior recovers
  the analytic moments within ~4x the Monte Carlo error.

* xfail ``test_pymc_model_store_extra[mlx-pytensor]``: PyTensor's MLX
  linker Split dispatch passes a numpy array to ``mx.split`` (read as an
  equal-split count) and breaks on the uneven splits in the
  ZeroSumNormal/Dirichlet gradient graphs. mlx/mlx is unaffected.

Made-with: Cursor
The pymc test suite (and pymc-dev) started crashing with SIGSEGV on
Linux/Windows. Reproduced in a clean uv venv matching CI exactly: the
fault is in native pandas (e.g. pd.date_range in the unmodified
test_zarr_store), not in nutpie or the MLX backend.

Root cause: pandas 3.0 wheels are compiled against the numpy 2.5 C ABI
but their metadata only requires numpy>=1.26, while numba caps numpy<2.5.
uv therefore resolves pandas 3.0.4 together with numpy 2.4.6, and the ABI
mismatch segfaults in pandas' compiled extensions. Pinning pandas<3.0
(which is built against the numpy 2.x ABI) resolves it; the full pymc
suite then passes. Mirrors the existing arviz<1.0.0 pin. Drop once numba
supports numpy>=2.5.

Made-with: Cursor
The pymc (dev) and docs jobs installed the locally-built wheel first
(--no-deps --no-index), then ran a second `uv pip install nutpie[extra]`
without --find-links dist. That second resolution pulled nutpie from
PyPI and downgraded it to a release predating the newly-added extras:
nutpie 0.13.4 has no 'dev' extra (pytest never installed -> "pytest:
command not found") and 0.16.8 has no 'docs' extra (jupyter/nbformat
never installed -> `quarto render` fails with ModuleNotFoundError).

Add --find-links dist to both extras installs so the local wheel stays
the preferred (highest-version) candidate and the new extras resolve
against it.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
pymc (dev) and docs both resolve nutpie[extra] under uv-constraints-dev.txt,
which pinned BOTH pymc@main and pytensor@main. pymc main caps pytensor
(currently >=3.0.7,<3.1) while pytensor main is >=3.1, so the two git-main
pins are mutually unsatisfiable. uv resolves around it by backtracking nutpie
to an old PyPI release that has no dev/docs extra, so pytest/jupyter are never
installed (pytest: command not found; nbformat ModuleNotFoundError).

Pin only pymc@main and let it pull the released pytensor it actually supports
(3.0.7). Verified with `uv pip compile`: 159 packages resolve cleanly.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants