From a0c98d0896e397799213d808bd57bba908198ae1 Mon Sep 17 00:00:00 2001 From: Cyril362005 Date: Mon, 4 May 2026 13:19:12 +0530 Subject: [PATCH] ENH: delegate broadcast_shapes --- src/array_api_extra/__init__.py | 2 +- src/array_api_extra/_delegation.py | 64 +++++++++++++++++++++++++++++- src/array_api_extra/_lib/_funcs.py | 44 ++------------------ tests/test_funcs.py | 35 ++++++++++++++++ 4 files changed, 103 insertions(+), 42 deletions(-) diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 2fcdcd8e..4efe1705 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -3,6 +3,7 @@ from ._delegation import ( argpartition, atleast_nd, + broadcast_shapes, cov, create_diagonal, expand_dims, @@ -20,7 +21,6 @@ from ._lib._at import at from ._lib._funcs import ( apply_where, - broadcast_shapes, default_dtype, kron, nunique, diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 46639559..ee3d4123 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from types import ModuleType -from typing import Literal +from typing import Literal, cast from ._lib import _funcs from ._lib._utils._compat import ( @@ -20,6 +20,7 @@ __all__ = [ "atleast_nd", + "broadcast_shapes", "cov", "create_diagonal", "expand_dims", @@ -81,6 +82,67 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array return _funcs.atleast_nd(x, ndim=ndim, xp=xp) +def broadcast_shapes( + *shapes: tuple[float | None, ...], xp: ModuleType | None = None +) -> tuple[int | None, ...]: + """ + Compute the shape of the broadcasted arrays. + + Duplicates :func:`numpy.broadcast_shapes`, with additional support for + None and NaN sizes. + + This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape`` + without needing to worry about the backend potentially deep copying + the arrays. + + Parameters + ---------- + *shapes : tuple[int | None, ...] + Shapes of the arrays to broadcast. + xp : array_namespace, optional + The standard-compatible namespace to use for native delegation. + Default: use the array-agnostic implementation. + + Returns + ------- + tuple[int | None, ...] + The shape of the broadcasted arrays. + + See Also + -------- + numpy.broadcast_shapes : Equivalent NumPy function. + array_api.broadcast_arrays : Function to broadcast actual arrays. + + Notes + ----- + This function accepts the Array API's ``None`` for unknown sizes, + as well as Dask's non-standard ``math.nan``. + Regardless of input, the output always contains ``None`` for unknown sizes. + + Examples + -------- + >>> import array_api_extra as xpx + >>> xpx.broadcast_shapes((2, 3), (2, 1)) + (2, 3) + >>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3)) + (4, 2, 3) + """ + if ( + xp is not None + and all(isinstance(size, int) for shape in shapes for size in shape) + and ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_jax_namespace(xp) + or is_torch_namespace(xp) + ) + ): + int_shapes = cast(tuple[tuple[int, ...], ...], shapes) + return cast(tuple[int | None, ...], xp.broadcast_shapes(*int_shapes)) + + return _funcs.broadcast_shapes(*shapes) + + def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: """ Estimate a covariance matrix (or a stack of covariance matrices). diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 97904ddb..0d7ad988 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -220,46 +220,10 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array: # `float` in signature to accept `math.nan` for Dask. # `int`s are still accepted as `float` is a superclass of `int` in typing -def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...]: - """ - Compute the shape of the broadcasted arrays. - - Duplicates :func:`numpy.broadcast_shapes`, with additional support for - None and NaN sizes. - - This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape`` - without needing to worry about the backend potentially deep copying - the arrays. - - Parameters - ---------- - *shapes : tuple[int | None, ...] - Shapes of the arrays to broadcast. - - Returns - ------- - tuple[int | None, ...] - The shape of the broadcasted arrays. - - See Also - -------- - numpy.broadcast_shapes : Equivalent NumPy function. - array_api.broadcast_arrays : Function to broadcast actual arrays. - - Notes - ----- - This function accepts the Array API's ``None`` for unknown sizes, - as well as Dask's non-standard ``math.nan``. - Regardless of input, the output always contains ``None`` for unknown sizes. - - Examples - -------- - >>> import array_api_extra as xpx - >>> xpx.broadcast_shapes((2, 3), (2, 1)) - (2, 3) - >>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3)) - (4, 2, 3) - """ +def broadcast_shapes( # numpydoc ignore=PR01,RT01 + *shapes: tuple[float | None, ...], +) -> tuple[int | None, ...]: + """See docstring in array_api_extra._delegation.""" if not shapes: return () # Match NumPy output diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 6a11e059..af36f642 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -489,6 +489,41 @@ def test_5D_values(self, xp: ModuleType): class TestBroadcastShapes: + def test_delegates_known_integer_shapes(self, monkeypatch: pytest.MonkeyPatch): + calls = [] + + def mock_broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]: + calls.append(shapes) + return (99,) + + monkeypatch.setattr(np, "broadcast_shapes", mock_broadcast_shapes) + + assert broadcast_shapes((2,), (1,), xp=np) == (99,) + assert calls == [((2,), (1,))] + + def test_fallback_for_unknown_sizes(self, monkeypatch: pytest.MonkeyPatch): + def mock_broadcast_shapes(*_shapes: tuple[int, ...]) -> tuple[int, ...]: + msg = "Native delegation should not handle unknown sizes" + raise AssertionError(msg) + + monkeypatch.setattr(np, "broadcast_shapes", mock_broadcast_shapes) + + assert broadcast_shapes((None,), (1,), xp=np) == (None,) + assert broadcast_shapes((math.nan,), (1,), xp=np) == (None,) + + def test_fallback_without_xp(self, monkeypatch: pytest.MonkeyPatch): + def mock_broadcast_shapes(*_shapes: tuple[int, ...]) -> tuple[int, ...]: + msg = "Native delegation should not be used without xp" + raise AssertionError(msg) + + monkeypatch.setattr(np, "broadcast_shapes", mock_broadcast_shapes) + + assert broadcast_shapes((2,), (1,)) == (2,) + + @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp") + def test_xp(self, xp: ModuleType): + assert broadcast_shapes((2, 3), (2, 1), xp=xp) == (2, 3) + @pytest.mark.parametrize( "args", [