Add MLX backend support for Nutpie compilation#254
Conversation
Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or Pytensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX.
|
Thanks, that looks great! |
Bump MLX version requirement to >=0.29.0 in pyproject.toml for all relevant extras. In compile_pymc.py, JIT compile the logp function using mx.compile for improved performance, aligning with JAX backend behavior.
Good point, that simple addition brings between 5% to 20% more performance! @aseyboldt |
|
@aseyboldt solve the test issue to work only on macs with intel chips. |
|
@aseyboldt can you give me a hand? The test failing its strange. My local pass everythig. |
aseyboldt
left a comment
There was a problem hiding this comment.
That failure is annoying. For some reason the results seem to differ between different machines? I think we really should figure out what's going on here. Maybe it helps if we print the first couple of values in warmup_posterior to see if the initial values are already different, or if small differences accumulate?
| updated.update(**updates) | ||
|
|
||
| # Convert to MLX arrays if using MLX backend (indicated by force_single_core) | ||
| if self._force_single_core: |
There was a problem hiding this comment.
We should not use that argument to detect mlx.
How about we add an attribute _convert_data_item or so to the dataclass, that contains a function that transforms data arrays? We could then also use that for jax.
Resolve conflicts to keep the MLX backend goals while adopting upstream fixes and CI restructure: - compiled_pyfunc.py: combine upstream's new extra_callback / extra_callback_rate parameters in PyFuncModel._make_sampler with the MLX force_single_core guard. - .github/workflows/ci.yml: adopt upstream's split build/test jobs with the suite matrix (stan/pymc/flow), test_pymc_dev, docs, deploy-docs. Install mlx only in the macOS pymc suite on aarch64 (x86_64 macOS pymc is excluded per upstream's matrix). - compile_pymc.py auto-merge: keep _compile_pymc_model_mlx, MLX backend in compile_pymc_model and _make_functions, and adopt upstream's PyTensor compat refactor (pt.grad, allow_xtensor_conversion, pytensor imports). - tests/test_pymc.py auto-merge: keep dynamic backend_params with MLX guarded by find_spec, plus upstream's test_progress_callback. Made-with: Cursor
Address review comments and CI failures on PR pymc-devs#254: * MLX is not thread-safe (Metal command-buffer race, ml-explore/mlx#2133), so always set ``force_single_core=True`` for ``backend="mlx"`` regardless of ``gradient_backend``. This unblocks the default config (``gradient_backend="pytensor"``) with ``chains>=2``, which previously aborted with "A command encoder is already encoding to this command buffer". * Decouple shared-data type conversion from ``_force_single_core``: ``PyFuncModel`` now carries an optional ``_shared_data_converter`` callable (set to ``mx.array`` for MLX) used by ``with_data``. The old code was abusing ``_force_single_core`` as a "is MLX backend" proxy. * Drop the stale ``raw_logp_fn`` plumbing in the MLX backend. The transform adapter is flowjax-based (JAX-only), so MLX could never expose a usable raw logp. * Pin ``mlx<0.31`` in ``pyproject.toml`` extras and the CI install. mlx 0.31.x crashes inside ``Compiled::eval_gpu`` when the sampler worker thread evaluates auto-fused element-wise kernels (ml-explore/mlx#3329), which is the root cause of the macOS aarch64 pymc test segfaults. mlx 0.29.x and 0.30.x are unaffected. * Add a runtime guard in ``_compile_pymc_model_mlx`` that raises a helpful ``RuntimeError`` if mlx>=0.31 is installed anyway, instead of segfaulting partway through sampling. * Make ``test_deterministic_sampling_mlx`` a smoke test (drop ``array_compare`` + reference file). MLX sampling is not bit-identical across machines/MLX versions, mirroring the situation already documented for ``test_normalizing_flow``. * Mark ``test_dims_model[mlx-*]`` as ``xfail`` while pymc-devs/pytensor#1350 is open (PyTensor's MLX linker has no ``XTensorFromTensor`` op yet). * Mirror logp's ``mx.array(_x)`` conversion in the expand closure for consistency. * Ignore local AI tooling folders (``.ai/``, ``.cursor/``, ``.claude/``). Made-with: Cursor
Keep only the upstream-issue references (mlx#2133, mlx#3329) since those are non-obvious context; drop the explanatory comments whose content is already conveyed by the surrounding code. Made-with: Cursor
…mlx_backend Co-authored-by: Cursor <cursoragent@cursor.com> # Conflicts: # python/nutpie/compiled_pyfunc.py
Deep-dive into the cross-machine non-determinism flagged in review: the MLX backend evaluates the logp/gradient on the Metal GPU, where float64 is unsupported, so ``mx.array(_x)`` silently downcasts nutpie's float64 positions to float32. Measured relative error is ~1e-7 (float32 eps), and because NUTS is a chaotic integrator that error accumulates across the leapfrog steps and varies with GPU/Metal/kernel fusion -- which is why the array_compare reference test could never match between machines. * Cast inputs explicitly to ``mx.float32`` in the logp and expand closures so the precision contract is intentional and cannot silently change with the MLX version or a global default-dtype override, and document the root cause at the conversion site. * Replace the bare smoke test with a machine-independent statistical correctness check: sample HalfNormal and assert the posterior recovers the analytic moments within ~4x the Monte Carlo error. * xfail ``test_pymc_model_store_extra[mlx-pytensor]``: PyTensor's MLX linker Split dispatch passes a numpy array to ``mx.split`` (read as an equal-split count) and breaks on the uneven splits in the ZeroSumNormal/Dirichlet gradient graphs. mlx/mlx is unaffected. Made-with: Cursor
The pymc test suite (and pymc-dev) started crashing with SIGSEGV on Linux/Windows. Reproduced in a clean uv venv matching CI exactly: the fault is in native pandas (e.g. pd.date_range in the unmodified test_zarr_store), not in nutpie or the MLX backend. Root cause: pandas 3.0 wheels are compiled against the numpy 2.5 C ABI but their metadata only requires numpy>=1.26, while numba caps numpy<2.5. uv therefore resolves pandas 3.0.4 together with numpy 2.4.6, and the ABI mismatch segfaults in pandas' compiled extensions. Pinning pandas<3.0 (which is built against the numpy 2.x ABI) resolves it; the full pymc suite then passes. Mirrors the existing arviz<1.0.0 pin. Drop once numba supports numpy>=2.5. Made-with: Cursor
The pymc (dev) and docs jobs installed the locally-built wheel first (--no-deps --no-index), then ran a second `uv pip install nutpie[extra]` without --find-links dist. That second resolution pulled nutpie from PyPI and downgraded it to a release predating the newly-added extras: nutpie 0.13.4 has no 'dev' extra (pytest never installed -> "pytest: command not found") and 0.16.8 has no 'docs' extra (jupyter/nbformat never installed -> `quarto render` fails with ModuleNotFoundError). Add --find-links dist to both extras installs so the local wheel stays the preferred (highest-version) candidate and the new extras resolve against it. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
pymc (dev) and docs both resolve nutpie[extra] under uv-constraints-dev.txt, which pinned BOTH pymc@main and pytensor@main. pymc main caps pytensor (currently >=3.0.7,<3.1) while pytensor main is >=3.1, so the two git-main pins are mutually unsatisfiable. uv resolves around it by backtracking nutpie to an old PyPI release that has no dev/docs extra, so pytest/jupyter are never installed (pytest: command not found; nbformat ModuleNotFoundError). Pin only pymc@main and let it pull the released pytensor it actually supports (3.0.7). Verified with `uv pip compile`: 159 packages resolve cleanly. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or Pytensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX.