diff --git a/python/infinicore/ops/paged_attention_prefill.py b/python/infinicore/ops/paged_attention_prefill.py index 848f74abf..259e865ae 100644 --- a/python/infinicore/ops/paged_attention_prefill.py +++ b/python/infinicore/ops/paged_attention_prefill.py @@ -2,6 +2,12 @@ from infinicore.tensor import Tensor +def _ensure_head_dim_contiguous(tensor: Tensor) -> Tensor: + if tensor.ndim > 0 and tensor.stride(tensor.ndim - 1) != 1: + return tensor.contiguous() + return tensor + + def paged_attention_prefill( q: Tensor, k_cache: Tensor, @@ -14,6 +20,8 @@ def paged_attention_prefill( *, out: Tensor | None = None, ): + k_cache = _ensure_head_dim_contiguous(k_cache) + v_cache = _ensure_head_dim_contiguous(v_cache) alibi_ptr = alibi_slopes._underlying if alibi_slopes is not None else None if out is None: