Description
Raised in a comment on #2085.
Background
PyTensor has a graph rewrite pass called BlasOpt that recognises common linear algebra patterns and replaces them with fused BLAS-level operations. It is already enabled for the NumPy and Numba backends. It does not appear to be enabled for the MLX backend, and the fused kernel dispatches are not registered.
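The three operations in scope:
- Gemm: fused matrix-matrix multiply with scaling and accumulation: alpha * A @ B + beta * C
- Gemv: same pattern for matrix × vector (Add MLX GEMM dispatch #2008)
- Ger: rank-1 update (outer product): alpha * x ⊗ y + A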
Without BlasOpt, a computation like 0.5 * (A @ B) + C compiles to three separate Metal dispatches: matmul → scale → add. With a fused Gemm it is a single kernel call.
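To make the fusion target concrete, here is a minimal NumPy sketch contrasting the unfused three-step form with a single Gemm-style expression. NumPy is used only as a neutral reference and says nothing about Metal dispatch counts; the helper names `unfused` and `gemm_reference` are illustrative and not part of PyTensor or MLX.

```python
import numpy as np

def unfused(A, B, C, alpha):
    # Three separate steps, mirroring matmul -> scale -> add.
    product = A @ B
    scaled = alpha * product
    return scaled + C

def gemm_reference(A, B, C, alpha, beta):
    # Gemm semantics: alpha * (A @ B) + beta * C, written as one expression.
    return alpha * (A @ B) + beta * C

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3))
B = rng.standard_normal((3, 5))
C = rng.standard_normal((4, 5))

out_unfused = unfused(A, B, C, alpha=0.5)
out_fused = gemm_reference(A, B, C, alpha=0.5, beta=1.0)
assert np.allclose(out_unfused, out_fused)
```

The rewrite pass matches the first shape in the graph and swaps in the second, so the two must agree numerically for any inputs.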
Why this matters for PyMC/MCMC workloads
Matrix operations appear in the hot path for:
- Multivariate Normal and LKJ-Cholesky likelihoods
- Gaussian Process covariance computations
- Any model with a dot product in the linear predictor (hierarchical regression, MMM)
- Time series models with state-space structure
These are not niche cases. Most non-trivial PyMC models will benefit.
Proposed work
1. Register fused dispatches in pytensor/link/mlx/dispatch/blas.py
MLX has mx.addmm, which computes beta * C + alpha * (A @ B), a direct match for Gemm:
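```python
import mlx.core as mx
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.blas import Gemm

@mlx_funcify.register(Gemm)
def mlx_funcify_Gemm(op, **kwargs):
    def gemm(C, alpha, A, B, beta):  # input order of PyTensor's Gemm: (z, a, x, y, b)
        return mx.addmm(C, A, B, alpha=alpha, beta=beta)
    return gemm
```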
For Gemv, mx.addmm should also handle the matrix-vector case (as far as I can tell, MLX promotes 1-D vectors to matrices for the multiply; worth confirming). For Ger, MLX has mx.outer; a fused version may require a small custom implementation or mx.addmm with reshaped inputs.
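For reference, the semantics a Gemv or Ger dispatch would need to reproduce can be written in NumPy as below. This is a sketch intended as a test oracle, not MLX code. The argument orders (y, alpha, A, x, beta) and (A, alpha, x, y) are my understanding of PyTensor's Gemv and Ger Ops and should be checked against pytensor/tensor/blas.py. The last line shows the "reshaped inputs" idea: an outer product is a (n, 1) by (1, m) matrix multiply.

```python
import numpy as np

def gemv_reference(y, alpha, A, x, beta):
    # Matrix-vector analogue of Gemm: beta * y + alpha * (A @ x).
    return beta * y + alpha * (A @ x)

def ger_reference(A, alpha, x, y):
    # Rank-1 update: A + alpha * outer(x, y).
    return A + alpha * np.outer(x, y)

A = np.array([[1.0, 2.0], [3.0, 4.0]])
x = np.array([1.0, -1.0])
y = np.array([10.0, 20.0])

gemv_out = gemv_reference(y, 2.0, A, x, 0.5)
ger_out = ger_reference(A, 2.0, x, y)

# Same rank-1 update expressed as a matmul of reshaped vectors.
ger_via_matmul = A + 2.0 * (x[:, None] @ y[None, :])
```

If the reshaped-matmul form holds on MLX as well, Ger could be routed through the same fused call as Gemm instead of needing a custom kernel.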
2. Enable BlasOpt in the MLX linker/optimizer
In pytensor/link/mlx/linker.py (or wherever the MLX compilation mode is defined), add BlasOpt to the optimisation sequence, mirroring how it is registered for other backends.
3. Verify numerics and benchmark
Compare output against the NumPy backend on a representative model (e.g. multivariate normal logp with a 64×64 covariance matrix). Measure before/after on at least Gemm. A speedup is expected wherever the fused pattern appears, since fusion eliminates one or two extra kernel dispatches per matrix operation.
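A possible shape for the numerics check, sketched in NumPy only. `gemm_float32` is a hypothetical stand-in for the backend under test (MLX computes in float32 by default), and the tolerances are assumptions chosen for float32 against a float64 reference, not values taken from the PyTensor test suite.

```python
import numpy as np

def gemm_reference(A, B, C, alpha, beta):
    # float64 reference: alpha * (A @ B) + beta * C.
    return alpha * (A @ B) + beta * C

def gemm_float32(A, B, C, alpha, beta):
    # Hypothetical stand-in for the backend under test, computing in float32.
    A32, B32, C32 = (np.asarray(m, dtype=np.float32) for m in (A, B, C))
    return np.float32(alpha) * (A32 @ B32) + np.float32(beta) * C32

def max_abs_error(candidate, reference, args, rtol=1e-4, atol=1e-4):
    # Run both implementations on identical inputs; raise if they disagree.
    expected = np.asarray(reference(*args), dtype=np.float64)
    actual = np.asarray(candidate(*args), dtype=np.float64)
    np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)
    return float(np.max(np.abs(actual - expected)))

rng = np.random.default_rng(0)
A, B, C = (rng.standard_normal((64, 64)) for _ in range(3))
max_err = max_abs_error(gemm_float32, gemm_reference, (A, B, C, 0.5, 1.0))
```

In the real test, `gemm_float32` would be replaced by a function compiled with the MLX mode, and the same harness could be reused for Gemv and Ger.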
Effort
Low-to-medium. The rewrite pass already exists; this is dispatch registration and linker configuration, not new algorithm work. The main caution is verifying that alpha/beta scalar handling matches what PyTensor’s Gemm Op expects, and confirming mx.addmm behaviour on edge cases (beta=0, in-place accumulation).
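The beta=0 case deserves its own test because two reasonable implementations disagree when C holds NaN or uninitialised values. The NumPy sketch below contrasts the literal formula with the reference BLAS convention, under which C is not read when beta is zero. Which of the two mx.addmm follows is exactly what would need confirming; both function names are illustrative.

```python
import numpy as np

def gemm_naive(A, B, C, alpha, beta):
    # Literal formula: a NaN in C survives even when beta == 0, since 0 * NaN is NaN.
    return alpha * (A @ B) + beta * C

def gemm_blas_like(A, B, C, alpha, beta):
    # Reference BLAS convention: when beta == 0, C is not read at all.
    product = alpha * (A @ B)
    return product if beta == 0 else product + beta * C

A = np.eye(2)
B = np.array([[1.0, 2.0], [3.0, 4.0]])
C = np.full((2, 2), np.nan)

naive_out = gemm_naive(A, B, C, alpha=1.0, beta=0.0)
blas_out = gemm_blas_like(A, B, C, alpha=1.0, beta=0.0)
```

Here `naive_out` is all NaN while `blas_out` equals B, so a test that feeds NaN into C with beta=0 distinguishes the two behaviours.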