diff --git a/pytensor/link/mlx/dispatch/signal/conv.py b/pytensor/link/mlx/dispatch/signal/conv.py
index 481457d4da..eafc0bb87b 100644
--- a/pytensor/link/mlx/dispatch/signal/conv.py
+++ b/pytensor/link/mlx/dispatch/signal/conv.py
@@ -21,6 +21,13 @@ def conv1d(raw_data, raw_kernel, runtime_full_mode):
         data = mlx_typify(raw_data, dtype=None)
         kernel = mlx_typify(raw_kernel, dtype=None)
 
+        # Inside a Blockwise, a kernel broadcast across the batch is promoted to a
+        # leading length-1 dim (e.g. ``(1, K)``) instead of being vmapped away.
+        # ``mx.convolve`` needs 1-D inputs, so flatten it back to its core ``(K,)``.
+        # NOTE(review): a genuine ``(B, K)`` kernel would be flattened too; confirm.
+        if kernel.ndim > 1:
+            kernel = kernel.reshape(-1)
+
         if runtime_mode_static:
             runtime_mode = full_mode
         else:
diff --git a/tests/link/mlx/test_signal_conv.py b/tests/link/mlx/test_signal_conv.py
new file mode 100644
index 0000000000..844be2711a
--- /dev/null
+++ b/tests/link/mlx/test_signal_conv.py
@@ -0,0 +1,41 @@
+import numpy as np
+import pytest
+
+import pytensor.tensor as pt
+from pytensor.tensor.type import matrix, vector
+from tests.link.mlx.test_basic import compare_mlx_and_py
+
+
+pytest.importorskip("mlx.core")
+
+
+@pytest.mark.parametrize("mode", ["full", "valid"])
+def test_convolve1d_vector(mode):
+    x = vector("x", dtype="float32")
+    k = vector("k", dtype="float32")
+    out = pt.signal.conv.convolve1d(x, k, mode=mode)
+
+    rng = np.random.default_rng(0)
+    x_np = rng.standard_normal(32).astype("float32")
+    k_np = rng.standard_normal(5).astype("float32")
+
+    compare_mlx_and_py([x, k], [out], [x_np, k_np])
+
+
+@pytest.mark.parametrize("mode", ["full", "valid"])
+def test_convolve1d_batched_kernel_broadcast(mode):
+    """Convolve a batch of signals with one shared 1-D kernel on MLX.
+
+    The vector kernel is wrapped in a Blockwise that broadcasts it to a
+    leading-1 dim, so the MLX core thunk must flatten it back to 1-D before
+    calling ``mx.convolve``. Regression test for #2092.
+    """
+    x = matrix("x", dtype="float32")
+    k = vector("k", dtype="float32")
+    out = pt.signal.conv.convolve1d(x, k, mode=mode)
+
+    rng = np.random.default_rng(0)
+    x_np = rng.standard_normal((4, 32)).astype("float32")
+    k_np = rng.standard_normal(5).astype("float32")
+
+    compare_mlx_and_py([x, k], [out], [x_np, k_np])