diff --git a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/Bf16LoadPolicy.kt b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/Bf16LoadPolicy.kt
new file mode 100644
index 00000000..192bf9ab
--- /dev/null
+++ b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/Bf16LoadPolicy.kt
@@ -0,0 +1,56 @@
+package sk.ainet.io.safetensors
+
+/**
+ * Controls how the SafeTensors loader handles `BFLOAT16` (BF16) tensors.
+ *
+ * BF16 weights are common in modern transformer checkpoints (Gemma-3n,
+ * many Llama-derivatives shipped from HuggingFace). The decision is
+ * whether to **dequantise to FP32 at load time** (status quo — every
+ * BF16 weight doubles its memory footprint and the dequant pass runs
+ * once per checkpoint load) or to **keep the BF16 bytes native** and
+ * let the matmul dispatch in `DefaultCpuOps` route to a vectorised
+ * BF16 kernel.
+ *
+ * The two paths produce numerically equivalent results when both reach
+ * the same matmul kernel — the BF16 → FP32 conversion is the bit-shift
+ * identity `floatBits = (bf16 and 0xFFFF) shl 16`, applied either at
+ * load ([DEQUANT_TO_FP32]) or per-multiply ([KEEP_NATIVE]).
+ *
+ * Existing consumers should keep the default ([DEQUANT_TO_FP32]). Flip
+ * to [KEEP_NATIVE] **only** after confirming the runtime dispatch has
+ * BF16 support — at the time of writing that's SKaiNET-transformers
+ * builds against SKaiNET develop with the BF16 dispatch wired into
+ * `DefaultCpuOpsJvm.chooseQuantizedMatmul` (Phase 3 follow-up,
+ * separately tracked).
+ */
+public enum class Bf16LoadPolicy {
+    /**
+     * Default. Dequantise every BFLOAT16 tensor to FP32 at load time
+     * via the existing `dequantBF16` helper, then wrap as a
+     * `FloatArrayTensorData` (same as `FLOAT32` source tensors).
+     *
+     * - Memory cost: 2× the on-disk size for each BF16 tensor.
+     * - Runtime: zero extra dispatch — every matmul gets FP32 operands.
+     */
+    DEQUANT_TO_FP32,
+
+    /**
+     * Keep BFLOAT16 tensors in their on-disk packed-2-bytes-per-element
+     * layout. The loader emits a `Bf16DenseTensorData` (in
+     * `skainet-lang-core`) instead of dequantising; the tensor still
+     * advertises FP32 dtype to consumers (the underlying `get` decodes
+     * on read), but its `tensor.data` is recognisable as
+     * `Bf16TensorData` so a matmul dispatch can route to the SIMD
+     * `Bf16MatmulKernel` SPI.
+     *
+     * - Memory cost: identical to the on-disk size — no doubling.
+     * - Runtime: matmul dispatch picks up the BF16 SPI kernel when one
+     *   is registered; falls back to per-multiply dequant otherwise.
+     *
+     * **Caveat**: any non-matmul op that touches the BF16 tensor pays
+     * a per-element decode cost via `get`. Don't flip this unless the
+     * model's hot path is dominated by matmuls (the typical
+     * transformer case).
+     */
+    KEEP_NATIVE,
+}
diff --git a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt index 1aff343c..4fd0b81c 100644 --- a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt +++ b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt @@ -6,6 +6,8 @@ import sk.ainet.io.RandomAccessSource import sk.ainet.io.model.DataType import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.data.Bf16DenseTensorData +import sk.ainet.lang.tensor.data.TensorData import sk.ainet.lang.types.DType import sk.ainet.lang.types.FP32 import sk.ainet.lang.types.Int32 @@ -23,7 +25,8 @@ import kotlin.reflect.KClass * - F32/F64 tensors -> FP32 (F64 downcast with warning) * - I32/I64 tensors -> Int32 (I64 downcast with warning) * - I8/U8 tensors -> Int8 - * - F16/BF16 tensors -> FP32 (with dequantization) + * - F16 tensors -> 
FP32 (with dequantization) + * - BF16 tensors -> FP32 (default) OR native BF16 storage (`bf16Policy = KEEP_NATIVE`) * * Where possible, decoded arrays are wrapped (borrowed) rather than copied * into TensorData, avoiding a second allocation. The raw-byte decode step @@ -31,10 +34,16 @@ import kotlin.reflect.KClass * * @param sourceProvider Factory providing RandomAccessSource to the SafeTensors file * @param onProgress Optional progress callback (current, total, tensorName) + * @param bf16Policy How to handle `BFLOAT16` tensors. Default is + * [Bf16LoadPolicy.DEQUANT_TO_FP32] — backward-compatible with all + * existing consumers. Flip to [Bf16LoadPolicy.KEEP_NATIVE] to keep + * weights in their on-disk BF16 layout and let the matmul dispatch + * route to a vectorised BF16 kernel. */ class SafeTensorsParametersLoader( private val sourceProvider: () -> RandomAccessSource, - private val onProgress: (current: Long, total: Long, message: String?) -> Unit = { _, _, _ -> } + private val onProgress: (current: Long, total: Long, message: String?) -> Unit = { _, _, _ -> }, + private val bf16Policy: Bf16LoadPolicy = Bf16LoadPolicy.DEQUANT_TO_FP32, ) : ParametersLoader { override suspend fun load( @@ -82,10 +91,22 @@ class SafeTensorsParametersLoader( DataType.BFLOAT16 -> { require(dtype == FP32::class) { - "SafeTensors BF16 tensor '${tensorInfo.name}' requires FP32 dtype (dequant), got ${dtype.simpleName}" + "SafeTensors BF16 tensor '${tensorInfo.name}' requires FP32 dtype, got ${dtype.simpleName}" + } + when (bf16Policy) { + Bf16LoadPolicy.DEQUANT_TO_FP32 -> { + val floats = dequantBF16(bytes) + ctx.wrapFloatArray(shape, dtype, floats) as Tensor + } + Bf16LoadPolicy.KEEP_NATIVE -> { + // Wrap the on-disk BF16 bytes directly. dtype stays FP32 from + // the consumer's POV (Bf16TensorData : TensorData + // decodes on read); the storage type is what the matmul + // dispatch will pattern-match on to pick the BF16 SPI kernel. 
+ val bf16Data = Bf16DenseTensorData(shape, bytes) + ctx.fromData(bf16Data as TensorData, dtype) + } + } } - val floats = dequantBF16(bytes) - ctx.wrapFloatArray(shape, dtype, floats) as Tensor } DataType.INT32 -> {
diff --git a/skainet-io/skainet-io-safetensors/src/jvmTest/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoaderBf16PolicyTest.kt b/skainet-io/skainet-io-safetensors/src/jvmTest/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoaderBf16PolicyTest.kt
new file mode 100644
index 00000000..a42436a4
--- /dev/null
+++ b/skainet-io/skainet-io-safetensors/src/jvmTest/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoaderBf16PolicyTest.kt
@@ -0,0 +1,237 @@
+package sk.ainet.io.safetensors
+
+import kotlinx.coroutines.runBlocking
+import org.junit.Test
+import sk.ainet.context.DirectCpuExecutionContext
+import sk.ainet.io.JvmRandomAccessSource
+import sk.ainet.lang.tensor.Tensor
+import sk.ainet.lang.tensor.data.Bf16DenseTensorData
+import sk.ainet.lang.tensor.data.Bf16TensorData
+import sk.ainet.lang.tensor.data.FloatArrayTensorData
+import sk.ainet.lang.types.FP32
+import java.io.File
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import java.nio.file.Files
+import kotlin.math.abs
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+
+/**
+ * Tests for [SafeTensorsParametersLoader]'s BF16-handling under both
+ * [Bf16LoadPolicy] options.
+ *
+ * The default `DEQUANT_TO_FP32` policy preserves existing behavior — BF16
+ * tensors come out as `FloatArrayTensorData` with dequanted values. The
+ * `KEEP_NATIVE` policy emits `Bf16DenseTensorData` with the on-disk bytes
+ * intact; element access still produces FP32 values, but storage is half
+ * the size and recognizable by the matmul dispatch.
+ */
+class SafeTensorsParametersLoaderBf16PolicyTest {
+
+    /**
+     * Absolute tolerance for BF16 round-trips (7-bit mantissa ≈ 1/128).
+     * The fixture values used here are exactly representable in BF16, so
+     * this is headroom rather than an expected error.
+     */
+    private val bf16AbsTol = 1e-2f
+
+    /** Truncate FP32 to little-endian BF16 bytes (high 16 bits, zero rounding). */
+    private fun fp32ToBf16Bytes(values: FloatArray): ByteArray {
+        val out = ByteArray(values.size * 2)
+        for (i in values.indices) {
+            val bf16 = (values[i].toRawBits() ushr 16) and 0xFFFF
+            out[i * 2] = (bf16 and 0xFF).toByte() // lo
+            out[i * 2 + 1] = ((bf16 ushr 8) and 0xFF).toByte() // hi
+        }
+        return out
+    }
+
+    /**
+     * Write a SafeTensors file to a fresh temp file: the 8-byte
+     * little-endian header length, the UTF-8 JSON header, then each of
+     * [payloads] in order. The file is registered with `deleteOnExit()`.
+     */
+    private fun writeSafeTensorsFile(
+        prefix: String,
+        headerJson: String,
+        vararg payloads: ByteArray,
+    ): File {
+        val headerBytes = headerJson.toByteArray(Charsets.UTF_8)
+        val headerSizeBytes = ByteBuffer.allocate(8)
+            .order(ByteOrder.LITTLE_ENDIAN)
+            .putLong(headerBytes.size.toLong())
+            .array()
+        val tempFile = Files.createTempFile(prefix, ".safetensors").toFile()
+        tempFile.deleteOnExit()
+        tempFile.outputStream().use { out ->
+            out.write(headerSizeBytes)
+            out.write(headerBytes)
+            payloads.forEach { out.write(it) }
+        }
+        return tempFile
+    }
+
+    /** Build a single-tensor BF16 SafeTensors file on disk; return the temp File. */
+    private fun createBF16SafeTensorsFile(
+        name: String,
+        shape: List<Long>,
+        values: FloatArray,
+    ): File {
+        val bytes = fp32ToBf16Bytes(values)
+        val shapeStr = shape.joinToString(", ")
+        val headerJson =
+            "{\"$name\": {\"dtype\": \"BF16\", \"shape\": [$shapeStr], \"data_offsets\": [0, ${bytes.size}]}}"
+        return writeSafeTensorsFile("test_bf16_safetensors", headerJson, bytes)
+    }
+
+    /** Build a mixed FP32 + BF16 SafeTensors file — used for the mixed-dtype test. */
+    private fun createMixedSafeTensorsFile(
+        bf16Values: FloatArray,
+        fp32Values: FloatArray,
+    ): File {
+        val bf16Bytes = fp32ToBf16Bytes(bf16Values)
+        val fp32Buf = ByteBuffer.allocate(fp32Values.size * 4).order(ByteOrder.LITTLE_ENDIAN)
+        fp32Values.forEach { fp32Buf.putFloat(it) }
+        val fp32Bytes = fp32Buf.array()
+        val bf16End = bf16Bytes.size.toLong()
+        val fp32End = bf16End + fp32Bytes.size
+        val headerJson =
+            "{\"weight_bf16\": {\"dtype\": \"BF16\", \"shape\": [${bf16Values.size}], \"data_offsets\": [0, $bf16End]}," +
+                "\"weight_fp32\": {\"dtype\": \"F32\", \"shape\": [${fp32Values.size}], \"data_offsets\": [$bf16End, $fp32End]}}"
+        return writeSafeTensorsFile("test_mixed_safetensors", headerJson, bf16Bytes, fp32Bytes)
+    }
+
+    /**
+     * Load every tensor in [file] as FP32 and return them keyed by name.
+     *
+     * When [policy] is `null` the loader is built without a `bf16Policy`
+     * argument, so the constructor's own default is what gets exercised.
+     */
+    private fun loadAll(
+        file: File,
+        policy: Bf16LoadPolicy? = null,
+    ): Map<String, Tensor<FP32, Float>> = runBlocking {
+        val ctx = DirectCpuExecutionContext.create()
+        val loader = if (policy == null) {
+            SafeTensorsParametersLoader(sourceProvider = { JvmRandomAccessSource.open(file) })
+        } else {
+            SafeTensorsParametersLoader(
+                sourceProvider = { JvmRandomAccessSource.open(file) },
+                bf16Policy = policy,
+            )
+        }
+        val out = mutableMapOf<String, Tensor<FP32, Float>>()
+        loader.load(ctx, FP32::class) { name: String, tensor: Tensor<FP32, Float> ->
+            out[name] = tensor
+        }
+        out
+    }
+
+    @Test
+    fun bf16_default_policy_dequants_to_fp32_floatArray() {
+        val values = floatArrayOf(0.0f, 1.0f, -1.0f, 0.5f, 3.0f, -2.5f)
+        val file = createBF16SafeTensorsFile("weight", listOf(values.size.toLong()), values)
+
+        // No policy is passed, so this exercises the loader's constructor
+        // default, which must be DEQUANT_TO_FP32.
+        val tensors = loadAll(file)
+        val weight = tensors["weight"] ?: error("missing 'weight'")
+
+        assertTrue(
+            weight.data is FloatArrayTensorData<*>,
+            "DEQUANT_TO_FP32 policy must produce FloatArrayTensorData, got ${weight.data::class.simpleName}",
+        )
+
+        val out = weight.data.copyToFloatArray()
+        assertEquals(values.size, out.size)
+        for (i in values.indices) {
+            assertTrue(
+                abs(values[i] - out[i]) <= bf16AbsTol,
+                "value mismatch at $i: in=${values[i]} out=${out[i]}",
+            )
+        }
+    }
+
+    @Test
+    fun bf16_keep_native_policy_emits_bf16DenseTensorData() {
+        val values = floatArrayOf(0.0f, 1.0f, -1.0f, 0.5f, 3.0f, -2.5f, 100.0f, -64.0f)
+        val file = createBF16SafeTensorsFile("weight", listOf(values.size.toLong()), values)
+
+        val tensors = loadAll(file, Bf16LoadPolicy.KEEP_NATIVE)
+        val weight = tensors["weight"] ?: error("missing 'weight'")
+
+        assertTrue(
+            weight.data is Bf16TensorData,
+            "KEEP_NATIVE policy must produce Bf16TensorData, got ${weight.data::class.simpleName}",
+        )
+        assertTrue(
+            weight.data is Bf16DenseTensorData,
+            "KEEP_NATIVE policy must use Bf16DenseTensorData (the only Bf16TensorData impl today)",
+        )
+
+        // packedData should be exactly the BF16 bytes we wrote — verbatim, no dequant pass.
+        val expectedBytes = fp32ToBf16Bytes(values)
+        val packed = (weight.data as Bf16DenseTensorData).packedData
+        assertContentEquals(
+            expectedBytes, packed,
+            "KEEP_NATIVE policy must preserve on-disk BF16 bytes byte-for-byte",
+        )
+
+        // get() decodes to FP32 — values should match the FP32 input within BF16 precision.
+        val decoded = weight.data.copyToFloatArray()
+        assertEquals(values.size, decoded.size)
+        for (i in values.indices) {
+            assertTrue(
+                abs(values[i] - decoded[i]) <= bf16AbsTol,
+                "BF16 decode mismatch at $i: in=${values[i]} decoded=${decoded[i]}",
+            )
+        }
+    }
+
+    @Test
+    fun bf16_keep_native_decoded_values_match_dequant_path_exactly() {
+        // Both policies apply the same `bf16_bits << 16` math; the only
+        // difference is WHEN. Decoded values should be bit-identical.
+        val values = FloatArray(64) { (it - 32) * 0.25f }
+        val file = createBF16SafeTensorsFile("w", listOf(values.size.toLong()), values)
+
+        val dequantTensor = loadAll(file, Bf16LoadPolicy.DEQUANT_TO_FP32)["w"] ?: error("missing 'w'")
+        val nativeTensor = loadAll(file, Bf16LoadPolicy.KEEP_NATIVE)["w"] ?: error("missing 'w'")
+        val dequant = dequantTensor.data.copyToFloatArray()
+        val native = nativeTensor.data.copyToFloatArray()
+
+        assertEquals(dequant.size, native.size)
+        for (i in dequant.indices) {
+            assertEquals(
+                dequant[i].toRawBits(), native[i].toRawBits(),
+                "bit-identity expected at $i: dequant=${dequant[i]} native=${native[i]}",
+            )
+        }
+    }
+
+    @Test
+    fun mixed_bf16_fp32_file_keep_native_only_affects_bf16() {
+        val bf16Values = floatArrayOf(1.0f, 2.0f, 4.0f, 8.0f)
+        val fp32Values = floatArrayOf(0.1f, 0.2f, 0.4f, 0.8f, 1.6f)
+        val file = createMixedSafeTensorsFile(bf16Values, fp32Values)
+
+        val tensors = loadAll(file, Bf16LoadPolicy.KEEP_NATIVE)
+        val bf16 = tensors["weight_bf16"] ?: error("missing 'weight_bf16'")
+        val fp32 = tensors["weight_fp32"] ?: error("missing 'weight_fp32'")
+
+        assertTrue(
+            bf16.data is Bf16TensorData,
+            "BF16 tensor must use Bf16TensorData under KEEP_NATIVE, got ${bf16.data::class.simpleName}",
+        )
+        assertTrue(
+            fp32.data is FloatArrayTensorData<*>,
+            "FP32 tensor must stay FloatArrayTensorData under KEEP_NATIVE, got ${fp32.data::class.simpleName}",
+        )
+
+        // FP32 values should round-trip exactly.
+        val fp32Out = fp32.data.copyToFloatArray()
+        assertContentEquals(fp32Values, fp32Out, "FP32 tensor values must be unchanged")
+    }
+}