BF16 dispatch chain (Phase 3/3): wire Bf16TensorData dispatch in DefaultCpuOpsJvm.chooseQuantizedMatmul #613

@michalharakal

Description

Phase 3: final wiring of the BF16 dispatch chain. Follows #610 (Bf16TensorData, merged) and #612 (loader KEEP_NATIVE policy, merged).

What's missing

After #610 + #612, BF16 weights loaded via SafeTensorsParametersLoader with bf16Policy = KEEP_NATIVE arrive as Bf16DenseTensorData-backed tensors. But DefaultCpuOps.matmul doesn't recognise this storage type yet — chooseQuantizedMatmul only matches Q4_K, Q6_K, Q8_0, and MemorySegment-backed quantised data. A BF16-tensor matmul falls through to super.matmul, which scalar-loops through the FP32 path (decoding BF16 per-element via Bf16DenseTensorData.get). The SIMD BF16 kernel from #605 stays unreachable.

This PR adds the is Bf16TensorData -> branch — the final piece of the chain.
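
For context on the per-element decode mentioned above: BF16 is the upper 16 bits of an IEEE 754 binary32 value, so widening to FP32 is a 16-bit left shift. A minimal sketch, assuming the weights are held as raw 16-bit values; `bf16BitsToFloat` is a hypothetical helper name, not the actual Bf16DenseTensorData.get implementation:

```kotlin
// Hypothetical helper (not the SKaiNET API): widen one BF16 value to FP32.
// The mask avoids sign extension when the Short is promoted to Int.
fun bf16BitsToFloat(bits: Short): Float =
    Float.fromBits((bits.toInt() and 0xFFFF) shl 16)

fun main() {
    println(bf16BitsToFloat(0x3F80.toShort())) // 1.0
    println(bf16BitsToFloat(0xC000.toShort())) // -2.0
}
```

Doing this once per element inside a scalar FP32 loop is the cost the fall-through path pays today.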

Scope

  1. Add private val bf16MatmulKernel: Bf16MatmulKernel by lazy { ... } in DefaultCpuOpsJvm. Mirrors fp32MatmulKernel (non-null with scalar fallback) rather than q8_0MatmulKernel (nullable with legacy fallback) because there's no legacy non-SPI BF16 kernel — the scalar SPI provider is the floor.
  2. Add is Bf16TensorData -> branch in chooseQuantizedMatmul's when (bData) block. Calls the SPI's matmul with dense (m, n, k) parameters — the BF16 kernel is a full 2D SGEMM, not a matvec like Q4_K/Q8_0, so no per-batch loop is needed.
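
The two scope items could look roughly like the sketch below. Everything in it is a stand-in: `TensorDataSketch`, `Bf16DataSketch`, `Fp32DataSketch`, `Bf16KernelSketch`, and `CpuOpsSketch` are hypothetical types written for illustration, and the real Bf16MatmulKernel signature and SPI lookup may differ. It only shows the shape of the change: a non-null lazy kernel plus a `when` branch that passes the dense (m, n, k) parameters in a single call.

```kotlin
// Stand-in types; the real SKaiNET interfaces and signatures may differ.
interface TensorDataSketch
class Fp32DataSketch(val values: FloatArray) : TensorDataSketch
class Bf16DataSketch(val bits: ShortArray) : TensorDataSketch // plays the role of Bf16TensorData

interface Bf16KernelSketch { // plays the role of Bf16MatmulKernel
    fun matmul(a: FloatArray, b: ShortArray, out: FloatArray, m: Int, n: Int, k: Int)
}

// Scalar floor: out[m, n] = a[m, k] x b[k, n], all row-major.
object ScalarBf16KernelSketch : Bf16KernelSketch {
    override fun matmul(a: FloatArray, b: ShortArray, out: FloatArray, m: Int, n: Int, k: Int) {
        for (i in 0 until m) for (j in 0 until n) {
            var acc = 0f
            for (p in 0 until k) {
                val w = Float.fromBits((b[p * n + j].toInt() and 0xFFFF) shl 16)
                acc += a[i * k + p] * w
            }
            out[i * n + j] = acc
        }
    }
}

class CpuOpsSketch {
    // Non-null and lazy; a real implementation would ask the SPI for the
    // best available provider here instead of hard-coding the scalar one.
    private val bf16MatmulKernel: Bf16KernelSketch by lazy { ScalarBf16KernelSketch }

    fun matmul(a: FloatArray, bData: TensorDataSketch, m: Int, n: Int, k: Int): FloatArray {
        val out = FloatArray(m * n)
        when (bData) {
            // Full 2D matmul in one call, so no per-batch loop.
            is Bf16DataSketch -> bf16MatmulKernel.matmul(a, bData.bits, out, m, n, k)
            // Stand-in for the existing FP32 path (super.matmul in the issue).
            is Fp32DataSketch ->
                for (i in 0 until m) for (j in 0 until n) {
                    var acc = 0f
                    for (p in 0 until k) acc += a[i * k + p] * bData.values[p * n + j]
                    out[i * n + j] = acc
                }
            else -> error("unsupported storage type in this sketch")
        }
        return out
    }
}
```

Keeping the kernel non-null means the branch never needs a legacy fallback check, which matches the reasoning in item 1.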

Tests

New Bf16MatmulDispatchTest in commonTest (mirrors Q8_0MatmulDispatchTest):

  • single_batch_matmul_against_bf16_weight_routes_correctly: [1, k] × [k, n] matmul against a Bf16DenseTensorData weight; compares against the ScalarBf16MatmulKernel reference within a 1e-2 * k tolerance.
  • multi_batch_matmul_against_bf16_weight_routes_correctly: [m, k] × [k, n] with m > 1.
  • llm_typical_attention_proj_matmul_routes_correctly: 512² shape.
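
The tolerance check in these tests could be sketched as follows. The helper names (`floatToBf16Bits`, `withinTolerance`) are hypothetical and not taken from the existing test suite; the FP32 to BF16 conversion uses round-to-nearest-even and deliberately ignores NaN handling, which a test fixture with finite values does not need.

```kotlin
import kotlin.math.abs

// Hypothetical fixture helper: round an FP32 value to BF16 (round-to-nearest-even).
// NaN is not handled in this sketch.
fun floatToBf16Bits(x: Float): Short {
    val bits = x.toRawBits()
    val rounding = 0x7FFF + ((bits ushr 16) and 1)
    return ((bits + rounding) ushr 16).toShort()
}

// Hypothetical assertion helper: every element must be within 1e-2 * k of the reference.
// The bound scales with k because each output element accumulates k products.
fun withinTolerance(actual: FloatArray, expected: FloatArray, k: Int): Boolean {
    require(actual.size == expected.size) { "shape mismatch" }
    val tol = 1e-2f * k
    return actual.indices.all { abs(actual[it] - expected[it]) <= tol }
}
```

A test would build the weight with `floatToBf16Bits`, run the matmul through the dispatch path and through the scalar reference, and assert `withinTolerance(dispatched, reference, k)`.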

Out of scope (separate follow-ups)

  • Gemma-3n end-to-end smoke test (probably in SKaiNET-transformers/llm-inference/gemma — exercises the full chain on a real model).
  • Phase B2 hand-tuned NEON intrinsics for the Panama BF16 kernel. The bench in #605 (BF16 matmul: add Bf16MatmulKernel + scalar/Panama/native implementations) showed Panama BF16 at ~1× scalar because of scratch-dequant; in-SIMD widening would lift that to 5-10×.
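
To illustrate the scratch-dequant versus in-SIMD widening distinction in the last item, here is a scalar stand-in for the two strategies. It is not Panama or NEON code and says nothing about real speedups; the function names are hypothetical. The point is only where the widening happens: into a temporary FP32 buffer first, or inside the multiply-accumulate loop.

```kotlin
fun widenBf16(bits: Short): Float = Float.fromBits((bits.toInt() and 0xFFFF) shl 16)

// Strategy A (scratch-dequant): decode the whole BF16 row into an FP32
// scratch buffer, then run a plain FP32 dot product over it.
fun dotViaScratch(a: FloatArray, w: ShortArray, scratch: FloatArray): Float {
    for (i in w.indices) scratch[i] = widenBf16(w[i])
    var acc = 0f
    for (i in a.indices) acc += a[i] * scratch[i]
    return acc
}

// Strategy B (fused widening): decode each weight right where it is
// consumed, so no scratch buffer is written and re-read.
fun dotFused(a: FloatArray, w: ShortArray): Float {
    var acc = 0f
    for (i in a.indices) acc += a[i] * widenBf16(w[i])
    return acc
}
```

Both functions return the same value for the same inputs; the vectorised follow-up would apply strategy B with the shift done in SIMD registers.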

Branch

feature/bf16-dispatch off develop.

Metadata

Labels

enhancement
