
PERF: Enable BlasOpt and dispatch gemm/gemv/ger to fused MLX kernels #2213

Description

@drbenvincent


Raised in a comment on #2085.

Background

PyTensor has a graph rewrite pass called BlasOpt that recognises common linear algebra patterns and replaces them with fused BLAS-level operations. It is already enabled for the NumPy and Numba backends. It does not appear to be enabled for the MLX backend, and the fused kernel dispatches are not registered.

The three operations in scope:

  • Gemm: fused matrix-matrix multiply with scaling and accumulation, alpha * A @ B + beta * C
  • Gemv: the same pattern for matrix × vector, alpha * A @ x + beta * y (related: #2008, "Add MLX GEMM dispatch")
  • Ger: rank-1 update (outer product), alpha * x ⊗ y + A
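As a reference point for the numerics check in step 3, the three ops can be written out in plain NumPy. This is a minimal out-of-place sketch: the argument orders (z, alpha, x, y, beta) for Gemm, (y, alpha, A, x, beta) for Gemv and (A, alpha, x, y) for Ger are our reading of PyTensor's Op inputs and should be checked against pytensor/tensor/blas.py. The names gemm_ref, gemv_ref and ger_ref are hypothetical.

```python
import numpy as np

# Reference semantics only (hypothetical helpers, not PyTensor's implementation).
# Argument order is assumed to follow PyTensor's Op inputs:
#   Gemm(z, alpha, x, y, beta), Gemv(y, alpha, A, x, beta), Ger(A, alpha, x, y)
# These return new arrays; PyTensor's real Ops may also run in place.


def gemm_ref(z, alpha, x, y, beta):
    """beta * z + alpha * (x @ y) for matrices x (m, k), y (k, n), z (m, n)."""
    return beta * z + alpha * (x @ y)


def gemv_ref(y, alpha, A, x, beta):
    """beta * y + alpha * (A @ x) for a matrix A (m, n), vectors x (n,), y (m,)."""
    return beta * y + alpha * (A @ x)


def ger_ref(A, alpha, x, y):
    """A + alpha * outer(x, y): a rank-1 update of A (m, n) by x (m,), y (n,)."""
    return A + alpha * np.outer(x, y)


rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 2))
C = rng.standard_normal((3, 2))
# The motivating expression 0.5 * (A @ B) + C is Gemm with alpha=0.5, beta=1.
print(np.allclose(gemm_ref(C, 0.5, A, B, 1.0), 0.5 * (A @ B) + C))  # True
```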

Without BlasOpt, a computation like 0.5 * (A @ B) + C compiles to three separate Metal dispatches (matmul → scale → add). With a fused Gemm, it becomes a single kernel call.
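To make the dispatch count concrete, here is a toy rewrite over tuple-based expression trees that folds add(mul(alpha, dot(x, y)), z) into a single gemm node. It illustrates the idea only and is not PyTensor's BlasOpt, which works on PyTensor graphs and covers many more patterns. The node names and helper functions are hypothetical, and the number of op nodes is used as a rough stand-in for the number of kernel dispatches.

```python
import numpy as np

# Toy expression nodes (hypothetical, not PyTensor's graph types):
#   ("dot", x, y), ("mul", alpha, e), ("add", e1, e2), ("gemm", z, alpha, x, y, beta)
# Leaves are NumPy arrays or Python floats.


def evaluate(e):
    """Recursively evaluate a toy expression tree with NumPy."""
    if not isinstance(e, tuple):
        return e
    op, *args = e
    vals = [evaluate(a) for a in args]
    if op == "dot":
        return vals[0] @ vals[1]
    if op == "mul":
        return vals[0] * vals[1]
    if op == "add":
        return vals[0] + vals[1]
    if op == "gemm":
        z, alpha, x, y, beta = vals
        return beta * z + alpha * (x @ y)
    raise ValueError(f"unknown op {op!r}")


def count_ops(e):
    """Count op nodes, a rough proxy for the number of kernel dispatches."""
    if not isinstance(e, tuple):
        return 0
    return 1 + sum(count_ops(a) for a in e[1:])


def fuse_gemm(e):
    """Rewrite add(mul(alpha, dot(x, y)), z) into gemm(z, alpha, x, y, 1.0)."""
    if not isinstance(e, tuple):
        return e
    op, *args = e
    args = [fuse_gemm(a) for a in args]  # rewrite children first (bottom-up)
    if op == "add":
        lhs, rhs = args
        if (
            isinstance(lhs, tuple)
            and lhs[0] == "mul"
            and isinstance(lhs[2], tuple)
            and lhs[2][0] == "dot"
        ):
            _, alpha, (_, x, y) = lhs
            return ("gemm", rhs, alpha, x, y, 1.0)
    return (op, *args)


rng = np.random.default_rng(1)
A, B, C = (rng.standard_normal((4, 4)) for _ in range(3))
expr = ("add", ("mul", 0.5, ("dot", A, B)), C)  # 0.5 * (A @ B) + C
fused = fuse_gemm(expr)
print(count_ops(expr), count_ops(fused))  # 3 1
print(np.allclose(evaluate(expr), evaluate(fused)))  # True
```

The toy matches only the exact operand order shown; a real rewrite also has to handle commuted additions, subtraction, and scalars on either side of the product.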

Why this matters for PyMC/MCMC workloads

Matrix operations appear in the hot path for:

  • Multivariate Normal and LKJ-Cholesky likelihoods
  • Gaussian Process covariance computations
  • Any model with a dot product in the linear predictor (hierarchical regression, MMM)
  • Time series models with state-space structure

These are not niche cases: most non-trivial PyMC models contain at least one of these patterns and should benefit.

Proposed work

1. Register fused dispatches in pytensor/link/mlx/dispatch/blas.py

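MLX has mx.addmm(c, a, b, alpha=..., beta=...), which computes alpha * (a @ b) + beta * c. That is a direct match for Gemm, whose inputs in PyTensor are ordered (z, alpha, x, y, beta) and which computes beta * z + alpha * (x @ y):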

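import mlx.core as mx
from pytensor.link.mlx.dispatch.basic import mlx_funcify  # assumed module path
from pytensor.tensor.blas import Gemm

@mlx_funcify.register(Gemm)
def mlx_funcify_Gemm(op, **kwargs):
    # PyTensor's Gemm inputs are ordered (z, alpha, x, y, beta): beta * z + alpha * (x @ y)
    def gemm(z, alpha, x, y, beta):
        return mx.addmm(z, x, y, alpha=alpha, beta=beta)
    return gemm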

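For Gemv, mx.addmm should also cover the matrix-vector case, since a 1-D operand can be promoted to a matrix with a singleton dimension (or reshaped explicitly to an (n, 1) column). For Ger, MLX has mx.outer; a fused version may require a small custom implementation, or mx.addmm with x reshaped to an (m, 1) column and y to a (1, n) row.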
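The reshaping idea can be checked without Apple hardware by using a NumPy stand-in for mx.addmm. The sketch assumes mx.addmm(c, a, b, alpha, beta) computes alpha * (a @ b) + beta * c on 2-D inputs; addmm, gemv_via_addmm and ger_via_addmm are hypothetical names, and in the MLX dispatch the stand-in would be replaced by the real call. Reshaping explicitly avoids relying on how 1-D operands are promoted.

```python
import numpy as np


def addmm(c, a, b, alpha=1.0, beta=1.0):
    """Hypothetical NumPy stand-in for mx.addmm on 2-D inputs."""
    assert a.ndim == 2 and b.ndim == 2, "this stand-in only handles matrices"
    return alpha * (a @ b) + beta * c


def gemv_via_addmm(y, alpha, A, x, beta):
    # Treat x (n,) and y (m,) as columns, then drop the trailing singleton
    # axis from the (m, 1) result to get back a vector.
    return addmm(y[:, None], A, x[:, None], alpha=alpha, beta=beta)[:, 0]


def ger_via_addmm(A, alpha, x, y):
    # outer(x, y) equals an (m, 1) column times a (1, n) row, so a rank-1
    # update is addmm with beta=1 and reshaped vector inputs.
    return addmm(A, x[:, None], y[None, :], alpha=alpha, beta=1.0)


rng = np.random.default_rng(2)
A = rng.standard_normal((3, 4))
x, y = rng.standard_normal(4), rng.standard_normal(3)
print(np.allclose(gemv_via_addmm(y, 2.0, A, x, 0.5), 0.5 * y + 2.0 * (A @ x)))  # True
u, v = rng.standard_normal(3), rng.standard_normal(4)
print(np.allclose(ger_via_addmm(A, 2.0, u, v), A + 2.0 * np.outer(u, v)))  # True
```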

2. Enable BlasOpt in the MLX linker/optimizer

In pytensor/link/mlx/linker.py (or wherever the MLX compilation mode is defined), add BlasOpt to the optimisation sequence, mirroring how it is registered for other backends.

3. Verify numerics and benchmark

Compare output against the NumPy backend on a representative model (e.g. multivariate normal logp with a 64×64 covariance matrix). Measure before/after timings for at least Gemm. A speedup is expected wherever the fused pattern appears, since fusion removes one or two extra kernel dispatches per matrix operation; the relative gain is likely largest for small matrices, where dispatch overhead makes up more of the runtime.
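A numerics check for the suggested multivariate normal logp could look like the following NumPy sketch. MLX computes in float32 by default, so the float32 run stands in for the MLX result and the float64 run for the NumPy reference; the comparison should use a relative tolerance suited to float32 rather than exact equality. The name mvn_logp and any particular threshold (for example 1e-3) are illustrative assumptions, not project conventions.

```python
import numpy as np


def mvn_logp(x, mu, cov):
    """Multivariate normal log-density computed through a Cholesky factor."""
    d = x.shape[-1]
    L = np.linalg.cholesky(cov)               # cov = L @ L.T
    z = np.linalg.solve(L, x - mu)            # whitened residual L^{-1} (x - mu)
    half_logdet = np.sum(np.log(np.diag(L)))  # equals 0.5 * log|cov|
    return -0.5 * (d * np.log(2 * np.pi) + z @ z) - half_logdet


rng = np.random.default_rng(3)
d = 64
M = rng.standard_normal((d, d))
cov64 = M @ M.T / d + np.eye(d)  # symmetric positive definite, well conditioned
mu64 = rng.standard_normal(d)
x64 = rng.standard_normal(d)

ref = mvn_logp(x64, mu64, cov64)  # float64 reference (NumPy-backend stand-in)
lowp = mvn_logp(*(a.astype(np.float32) for a in (x64, mu64, cov64)))  # float32 run
rel_err = abs(float(lowp) - float(ref)) / abs(float(ref))
print(f"relative error float32 vs float64: {rel_err:.2e}")
```

np.linalg.solve keeps the sketch NumPy-only; a triangular solve such as scipy.linalg.solve_triangular would be the more efficient choice for the Cholesky factor.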

Effort

Low-to-medium. The rewrite pass already exists; this is dispatch registration and linker configuration, not new algorithm work. The main caution is verifying that alpha/beta scalar handling matches what PyTensor’s Gemm Op expects (PyTensor passes alpha and beta as 0-d tensor inputs, while mx.addmm takes them as Python floats), and confirming mx.addmm behaviour on edge cases (beta=0, where reference BLAS does not read C, and in-place accumulation).
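The beta=0 case deserves a dedicated test. Reference BLAS documents that C need not be set on input when beta is zero, so an implementation that literally evaluates beta * C + alpha * (A @ B) lets a NaN in C leak into the result (0 * NaN is NaN). The sketch below compares a naive formula with one that follows the BLAS convention; whether mx.addmm skips C when beta is 0 is exactly what would need checking. All function names here are hypothetical.

```python
import numpy as np


def gemm_naive(z, alpha, x, y, beta):
    # Direct formula: a NaN or Inf in z leaks through even when beta == 0,
    # because 0 * nan is nan in IEEE arithmetic.
    return beta * z + alpha * (x @ y)


def gemm_blas_convention(z, alpha, x, y, beta):
    # Reference BLAS convention: when beta == 0, z is not read at all.
    if beta == 0:
        return alpha * (x @ y)
    return beta * z + alpha * (x @ y)


def check_beta_zero(gemm_fn):
    """Return True if gemm_fn ignores a NaN-filled z when beta == 0."""
    x = np.ones((2, 3))
    y = np.ones((3, 2))
    z = np.full((2, 2), np.nan)  # uninitialised-looking output buffer
    out = gemm_fn(z, 1.0, x, y, 0.0)
    return bool(np.all(np.isfinite(out)))


print(check_beta_zero(gemm_naive))            # False
print(check_beta_zero(gemm_blas_convention))  # True
```

In the MLX test suite, check_beta_zero could be pointed at a thin wrapper around the registered Gemm dispatch, converting its output to NumPy before the isfinite check. In-place accumulation is a separate concern and is not covered by this sketch.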

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests