diff --git a/docs/modules/ROOT/nav.adoc b/docs/modules/ROOT/nav.adoc index e4baba19..d264bdf9 100644 --- a/docs/modules/ROOT/nav.adoc +++ b/docs/modules/ROOT/nav.adoc @@ -32,6 +32,7 @@ .Contributing * xref:contributing/index.adoc[Audience and scope] * xref:contributing/build-from-source.adoc[Build from source] +* xref:contributing/dtype-model.adoc[The SKaiNET dtype model] * xref:contributing/benchmarks.adoc[Engine benchmark program] * xref:contributing/matmul-kernels.adoc[Reading the matmul benchmark] * xref:contributing/register-bench-runner.adoc[Register a self-hosted bench runner] diff --git a/docs/modules/ROOT/pages/contributing/dtype-model.adoc b/docs/modules/ROOT/pages/contributing/dtype-model.adoc new file mode 100644 index 00000000..ca72bd08 --- /dev/null +++ b/docs/modules/ROOT/pages/contributing/dtype-model.adoc @@ -0,0 +1,139 @@ += The SKaiNET DType Model +:description: How SKaiNET represents tensor dtypes across loaders, the kernel SPI, and the (in-progress) constraint-resolution pipeline — mapped onto the four-dtype concept from the dtype-policy RFC (#615). + +[NOTE] +==== +**Audience: SKaiNET maintainers and contributors.** This page maps +the vocabulary used in the +https://github.com/SKaiNET-developers/SKaiNET/blob/develop/rfc.md[dtype-policy +RFC] (issue #615) onto the existing SKaiNET implementations. +Library consumers don't need to read this — they call +`tensor(ctx, FP32::class) { … }` and the engine does +the rest. +==== + +The RFC distinguishes four dtype concepts; the engine mostly already +implements them, but under different names. This page is the +glossary that keeps the two consistent. + +== The four dtype concepts + +[cols="1,2,2,3",options="header"] +|=== +| RFC term | What it means | SKaiNET implementation today | Notes +| **source dtype** | The dtype stored in the on-disk model file (`F16`, `F32`, `Q4_K`, `Q8_0`, …). | Set by the loader. 
`StreamingGgufParametersLoader` maps `GGMLQuantizationType.*` to the corresponding `TensorData` subtype. `SafeTensorsParametersLoader` maps SafeTensors `DataType` similarly. | The loader-time mapping is the **source-of-truth** for what the file actually contains. +| **logical dtype** | The dtype the tensor advertises to graph code (op contracts, shape inference, dispatch). | `Tensor.dtype: KClass` — the type-parameter `T` resolves to one of the sealed `DType` arms (`FP32`, `BF16`, `Int8`, …). | The logical dtype is **never** inferred from physical storage shape (no "1D byte array patched into 2D" antipattern — every quantized `TensorData` subtype carries explicit `shape: Shape`). +| **required dtype** | The dtype an op, layer, or backend declares it needs. | Today: implicit in the kernel SPI accessors (`matmulFp32()`, `matmulBf16()`, `matmulQ4K()`, `matmulQ8_0()`). After W6/W7 of #615: explicit `DTypePolicy` attached to graph nodes via `attributes["dtype_policy"]`. | The `DTypePolicy` sealed type (W1, shipped in this PR series) covers the four arms from the RFC's "policy categories" section: `Any` / `Require` / `Prefer` / `OneOf`. +| **lowered dtype** | The dtype actually passed to the executable kernel. | Whatever `KernelRegistry.bestAvailable()?.matmul*()` returns. `KernelProvider.supports(opName, dtypeKeys)` (W3, shipped) is the introspection query. | If a `Require` constraint can't be matched by any registered kernel and no cast kernel bridges the gap, the constraint-resolution pass (W7, pending) raises `DtypeConstraintViolationException` *before* forward execution — exactly the RFC's "fail before execution" rule. +|=== + +== Loader source → logical mapping today + +Both loaders are explicit about what each on-disk dtype becomes +inside the engine. This table is the W0a audit promised by issue +#615 — it makes the silent dequant cases visible so the loader-policy +work (W0b / W0c) knows what to generalise. 
+ +=== `StreamingGgufParametersLoader` (skainet-io-gguf) + +[cols="1m,1,2,2",options="header"] +|=== +| GGUF source type | Logical dtype today | Storage class | Native or dequant? +| F32 | `FP32` | `FloatArrayTensorData` (dense) | native +| I32 | `Int32` | `IntArrayTensorData` (dense) | native +| F16 | `FP32` | `FloatArrayTensorData` (dense, dequanted) | **dequant on load** — no `KEEP_NATIVE` path yet +| BF16 | `FP32` | `FloatArrayTensorData` (dense, dequanted) | **dequant on load** — no `KEEP_NATIVE` path yet +| Q4_K | `FP32`-tagged tensor wrapping `Q4_KBlockTensorData` | `Q4_KBlockTensorData` (packed, logical shape preserved) | native +| Q8_0 | `FP32`-tagged tensor wrapping `Q8_0BlockTensorData` | `Q8_0BlockTensorData` (packed, logical shape preserved) | native +|=== + +The two dequant rows (F16, BF16) are the gap. SafeTensors already +has a `Bf16LoadPolicy.KEEP_NATIVE` opt-in (see below) that returns +the BF16 bytes verbatim instead of expanding to FP32. The +equivalent for GGUF is W0c (`StreamingGgufParametersLoader.loadWithPolicy`). + +=== `SafeTensorsParametersLoader` (skainet-io-safetensors) + +[cols="1m,1,2,2",options="header"] +|=== +| SafeTensors source type | Logical dtype today | Storage class | Native or dequant? +| F32 / F64 | `FP32` | `FloatArrayTensorData` | native (F64 down-cast with warning) +| F16 | `FP32` | `FloatArrayTensorData` (dequanted) | **dequant on load** — no `KEEP_NATIVE` path yet +| BF16 | `FP32` or `BF16`-shaped depending on `Bf16LoadPolicy` | `FloatArrayTensorData` (dequanted) or `Bf16DenseTensorData` (native) | **policy-controlled**: `DEQUANT_TO_FP32` (default) or `KEEP_NATIVE` +| I32 / I16 / U16 / U32 / U64 / I8 / U8 | matching `Int*` / `UInt*` | wrapped / reinterpreted appropriately | native +|=== + +The BF16 row is the prior art for the RFC's policy model. 
`Bf16LoadPolicy.toDTypePolicy()` (W2, shipped) maps the BF16-specific enum onto the generalised `DTypePolicy`: + +[source,kotlin] +---- +Bf16LoadPolicy.DEQUANT_TO_FP32.toDTypePolicy() // → DTypePolicy.Require(FP32) +Bf16LoadPolicy.KEEP_NATIVE.toDTypePolicy() // → DTypePolicy.Require(BF16) +---- + +W0b extends this same idea to F16 and the integer dtypes so the +whole SafeTensors loader can be driven by a single `DTypePolicy` +argument. + +== The `DType` registry vs the kernel capability query + +`DType.findByName("Float32")` returns the singleton `FP32` object. +The sealed-interface registry is the source-of-truth for dtype +metadata (size in bits, name, promotion rules). It currently covers +floats and (un)signed integers from `Ternary` through `FP64`. + +The quantized block formats (`Q4_K`, `Q8_0`, `Q6_K`, `Q4_0`, …) +are **not** `DType` arms; they live as `TensorData` subtypes in +`skainet-lang-core/tensor/data/`. That's intentional: a `DType` is +a numeric type with promotion semantics, whereas Q4_K is a *packed +block format* with no scalar interpretation outside its block +context. + +For the kernel capability query (`KernelProvider.supports(opName, +dtypeKeys)`, W3), this means the second argument is `List<String>` +rather than `List<DType>`, so the strings `"Q4_K"` and `"Q8_0"` slot +in alongside `"Float32"` and `"BFloat16"`. The string convention +matches what GGUF / SafeTensors loaders and the StableHLO converter +already use for format identification. + +== Fail-fast: `KernelStrictness` + +The RFC's "fail before execution" rule has a small, ready +affordance today (W4, shipped): + +[source,bash] +---- +java -Dskainet.strict.kernels=true … +---- + +When set, `DefaultCpuOpsJvm.matmul` raises +`NoSuchKernelException` (with the failing dtype pair and the list +of currently-registered providers) just before its silent scalar +fallback would have run. The flag is off by default, so adaptive +behaviour is preserved. 
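The gating described above can be sketched in a few lines. This is only an illustration of the pattern, not the shipped `KernelStrictness` code: `resolveOrFallback` and `NoKernelSketchException` are hypothetical names, and only the `skainet.strict.kernels` property name comes from this page.

```kotlin
// Hypothetical stand-in: the real NoSuchKernelException's constructor is not shown on this page.
class NoKernelSketchException(message: String) : RuntimeException(message)

// Reads the system property on every call (an assumption; the real code may cache it).
fun strictKernelsEnabled(): Boolean =
    java.lang.Boolean.getBoolean("skainet.strict.kernels")

// Returns the specialised kernel if present; otherwise either fails fast (strict mode)
// or falls back to the scalar path (default adaptive behaviour).
fun <K : Any> resolveOrFallback(
    kernel: K?,
    dtypePair: Pair<String, String>,
    providers: List<String>,
    scalarFallback: () -> K,
): K {
    if (kernel != null) return kernel
    if (strictKernelsEnabled()) {
        throw NoKernelSketchException(
            "no kernel for $dtypePair; registered providers: $providers"
        )
    }
    return scalarFallback()
}
```

With the property unset, a missing kernel quietly resolves to the fallback; with `-Dskainet.strict.kernels=true`, the same call throws and reports the dtype pair and provider list.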
+ +The constraint-resolution pass (W7) raises the same exception +shape at *graph-prep* time, before forward execution can even +start. The `KernelStrictness` affordance is the runtime equivalent +for cases where graph prep hasn't been run (e.g. ad-hoc tensor-op +code that calls `ctx.ops.matmul` directly). + +== Anti-patterns this model prevents + +The RFC calls out three concrete anti-patterns the engine must +avoid; SKaiNET already prevents all three. + +[cols="2,3",options="header"] +|=== +| Anti-pattern | What prevents it in SKaiNET today +| Marker-class dtype detection (`if tensor is Q4_KMarker`) | The sealed `DType` interface carries explicit metadata (`sizeInBits`, `name`, `isCompatible`, `promoteTo`). Dispatch uses `KClass` identity and the typed accessors on `KernelProvider`, not marker checks. +| Packed bytes treated as logical shape (1D byte array patched into 2D after load) | Every quantized `TensorData` subtype (`Q4_KBlockTensorData`, `Q8_0BlockTensorData`, `Bf16DenseTensorData`) carries an explicit `shape: Shape` separate from its `packedData: ByteArray`. Loaders set the logical shape from the file header, not from `bytes.size`. +| GGUF Q8 confused with native int8 | They're different `TensorData` subtypes. A GGUF Q8 tensor goes through `Q8_0BlockTensorData` (with FP16 scale + 32 signed int8 codes per block); a future native-int8 NPU tensor would have its own `TensorData` subtype with backend-specific layout metadata. The RFC's "GGUF Q8 ≠ native int8" rule is enforced structurally. +|=== + +== Related + +* `rfc.md` (repo root) — the design document this page implements. +* Issue https://github.com/SKaiNET-developers/SKaiNET/issues/615[#615] — implementation tracker. +* xref:contributing/benchmarks.adoc[Engine benchmark program] — runtime numbers that the kernel SPI produces. +* xref:contributing/matmul-kernels.adoc[Reading the matmul benchmark] — how the kernel SPI's dispatch actually shows up in measurements. 
diff --git a/rfc.md b/rfc.md new file mode 100644 index 00000000..d8f82d24 --- /dev/null +++ b/rfc.md @@ -0,0 +1,626 @@ +# RFC: Hybrid Adaptive DSL with Optional DType Constraints + +**Status:** Draft + +Summary +This RFC proposes a hybrid adaptive DSL for model definition and execution. + +The DSL remains architecture-first by default: it describes layer topology, tensor roles, and graph structure without requiring a fixed dtype for every tensor. Tensor dtype normally follows the loaded model file. + +At the same time, the DSL may optionally express explicit dtype constraints where execution requires them. These constraints are resolved during load, compile, or lowering, before forward execution begins. + +This provides two important properties: + +A single DSL definition can load different GGUF quantization variants. +Strict execution targets, such as NPUs, can require specific runtime dtypes and layouts. +The key rule is: + +DType annotations in the DSL describe executable requirements, not assumptions about the source file. + +Motivation +GGUF models frequently use heterogeneous per-tensor quantization. A single file may contain tensors in FP16, FP32, Q8, Q4, Q4_K, or other quantized formats. + +A strict DSL that hardcodes dtype into every layer has several drawbacks: + +one model architecture may require multiple DSL definitions for different quant variants +mixed-precision GGUF files become awkward to represent +loading arbitrary GGUF variants becomes harder +dtype policy becomes coupled to model architecture +conversion may be forced even when the current backend could execute the source dtype directly +An adaptive DSL solves this by allowing tensor dtype to follow the file. However, pure adaptivity is not enough for constrained execution targets. + +For example, an NPU may support native int8 execution but not GGUF Q8, Q4_K, or other packed quantized formats. 
In that case, the DSL or backend configuration must be able to require a specific executable dtype or layout. + +This RFC proposes combining both approaches: + +adaptive dtype behavior by default +explicit dtype constraints when needed +load/compile-time constraint resolution +backend-specific lowering before execution +Goals +Keep the DSL architecture-focused by default. +Allow one DSL definition to load multiple quantized model variants. +Support mixed-precision GGUF files. +Make dtype a first-class tensor property. +Allow explicit dtype constraints for specific ops, tensors, layers, or backends. +Resolve hard dtype requirements before forward execution. +Avoid marker-class-based dtype detection. +Avoid treating raw packed byte shape as logical tensor shape. +Separate source file dtype from executable backend dtype. +Support restricted backends such as NPUs without making the whole DSL strict. +Produce a dtype-safe prepared DAG before forward execution. +Support an optional compiled/lowered path for StableHLO, MLIR, or native optimized code. +Non-Goals +This RFC does not define a new tensor engine. +This RFC does not prescribe a specific Kotlin API. +This RFC does not require all tensors to be converted at load time. +This RFC does not require the DSL to declare every tensor dtype. +This RFC does not define exact quantization algorithms. +This RFC does not define backend-specific packed layouts. +This RFC does not require GGUF Q8 to be treated as native int8. +Definitions +Source dtype +The dtype stored in the model file. + +Examples: + +FP16 +FP32 +Q8 +Q4 +Q4_K +The source dtype describes what was read from disk. + +Logical dtype +The dtype represented by the tensor inside the engine. + +This should be explicit tensor metadata, not inferred from wrapper classes or raw storage type. + +Required dtype +The dtype required by an op, layer, backend, or execution policy. + +For example, an NPU backend may require int8 tensors for a given matrix multiplication. 
+ +Lowered dtype +The dtype and layout actually passed to the executable kernel. + +This may differ from the source dtype if conversion or lowering occurred. + +Logical shape +The shape of the tensor as seen by the graph. + +For example, a quantized matrix may logically be: + +[out_features, in_features] +even if it is physically stored as packed bytes. + +Physical storage layout +The internal memory representation of a tensor. + +For quantized tensors, this may include: + +packed bytes +block structure +scales +zero points +backend-specific layout metadata +Physical storage layout is an implementation detail of the tensor representation. + +Resolved DAG +A normalized internal graph produced from the DSL and loaded tensors. + +The resolved DAG makes execution metadata explicit on nodes and edges, including: + +tensor logical shapes +resolved dtypes +layouts +backend assignments +conversion nodes +lowering nodes +op dependencies +quantization metadata +dtype and backend constraints +The resolved DAG is a compiled intermediate representation, not necessarily the final executable artifact. + +Executable plan +A scheduled and backend-aware representation derived from the resolved DAG. + +The executable plan includes selected kernels, memory planning, buffer reuse, constant placement, lowered tensors, and backend-specific execution decisions. + +Lowering +The process of converting high-level graph operations, tensor dtypes, layouts, or storage formats into representations required by a selected backend. + +Examples include: + +Q4_K weight to native int8 NPU weight +GGUF packed layout to backend-native layout +high-level projection op to backend-specific matmul op +dynamic dtype choice to fixed kernel selection +resolved DAG to StableHLO, MLIR, or native backend code +Lowering is part of graph preparation. It may happen during loading if the target backend is already known, or during an explicit compile step if backend selection happens later. 
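The split between logical shape and physical storage layout can be made concrete with a small sketch. The names `TensorMetaSketch` and `q8_0PackedSize` are hypothetical and not part of any proposed API; the block size assumes the Q8_0 layout of 32 int8 codes plus one 2-byte FP16 scale per block.

```kotlin
// Hypothetical metadata record illustrating the definitions above; not a proposed API.
data class TensorMetaSketch(
    val logicalShape: List<Int>, // what the graph sees, e.g. [out_features, in_features]
    val sourceFormat: String,    // what was read from disk, e.g. "Q8_0"
    val packedBytes: Int,        // physical storage size, an implementation detail
)

// Assumed Q8_0 layout: 32 int8 codes plus one 2-byte FP16 scale, i.e. 34 bytes per 32 elements.
fun q8_0PackedSize(elementCount: Int): Int {
    require(elementCount % 32 == 0) { "Q8_0 needs a multiple of 32 elements" }
    return (elementCount / 32) * (32 + 2)
}

// Builds metadata for a Q8_0 tensor: the logical shape is kept as given,
// and the byte count is derived from it, never the other way round.
fun describeQ8_0(shape: List<Int>): TensorMetaSketch {
    val elements = shape.fold(1) { acc, dim -> acc * dim }
    return TensorMetaSketch(shape, "Q8_0", q8_0PackedSize(elements))
}
```

For a `[4, 64]` weight this gives 256 elements, 8 blocks, and 272 packed bytes, while the graph-visible shape stays `[4, 64]`.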
+ +Design Overview +The DSL defines model architecture and optional dtype constraints. + +The model file provides source tensors with source dtypes and logical shapes. + +The loader creates engine tensors with explicit dtype metadata. + +The compile or lowering phase resolves constraints against backend capabilities. + +If a hard constraint can be satisfied, tensors may be converted or lowered. If it cannot be satisfied, loading or compilation fails. + +Forward execution only sees resolved tensors. + +flowchart TD + A[DSL definition] --> B[Model architecture] + A --> C[Optional dtype constraints] + + D[Model file / GGUF] --> E[Source tensors with file dtypes] + + B --> F[Graph construction] + C --> G[Constraint resolution] + E --> G + + G --> H{Constraints satisfied?} + + H -- Yes, as-is --> I[Use source dtype directly] + H -- Requires conversion --> J[Lower / convert tensor] + H -- Impossible --> K[Fail at load or compile time] + + I --> L[Resolved runtime tensors] + J --> L + + L --> M[Kernel dispatch] + M --> N[Execution on CPU / SIMD / NPU] +Default Adaptive Behavior +If no dtype constraint is declared, the engine should preserve the dtype provided by the model file whenever possible. + +For example: + +GGUF tensor: Q4_K +DSL constraint: none +Backend: CPU +Result: keep Q4_K and dispatch Q4_K-capable kernel +This allows one DSL definition to support many model variants. + +flowchart LR + A[GGUF Q4/Q8/FP16/FP32 tensor] --> B[Engine tensor with explicit dtype] + B --> C[No hard dtype constraint] + C --> D[Keep source dtype] + D --> E[Dispatch by actual tensor dtype] +Explicit DType Constraints +The DSL may optionally declare that a tensor or op requires a specific dtype. + +Such annotations should be interpreted as execution constraints. + +They do not mean the source file must already contain that dtype. 
+ +For example: + +Source tensor: Q4_K +Required dtype: int8 +Backend: NPU +Resolution: lower Q4_K to backend-native int8, or fail +This allows restricted targets to express requirements without making the entire DSL strict. + +flowchart TD + A[Tensor loaded from file] --> B{Does DSL/backend require a specific dtype?} + + B -- No --> C[Keep file dtype] + B -- Yes --> D{Does current tensor already satisfy requirement?} + + D -- Yes --> E[Use directly] + D -- No --> F{Can it be converted/lowered?} + + F -- Yes --> G[Convert during load/compile] + F -- No --> H[Raise load/compile error] + + C --> I[Dispatch by resolved dtype] + E --> I + G --> I +DType Constraints as Policies +A dtype annotation should be modeled as a policy rather than a simple claim about storage. + +Useful policy categories include: + +Any +No specific dtype is required. + +The tensor may keep the source dtype. + +Require +A hard requirement. + +The executable graph is invalid unless the tensor is available in the required dtype and layout. + +Prefer +A soft requirement. + +The runtime should use the preferred dtype if available or cheap to produce, but may fall back to another supported dtype. + +One-of +A restricted set of acceptable dtypes. + +The runtime may choose any supported dtype from the allowed set. + +Native Int8 vs GGUF Quantized Formats +Native int8 and GGUF quantized formats must not be treated as equivalent. + +A GGUF Q8 tensor may be stored using int8-like values internally, but the tensor contract usually includes quantization metadata, block-level scales, and GGUF-specific layout semantics. + +A native int8 tensor for an NPU is an executable representation expected by that backend. It may require different layout, scale handling, alignment, calibration, or memory placement. + +Therefore: + +GGUF Q8 != native int8 +GGUF Q4_K != native int8 +packed quantized storage != executable integer tensor contract +The system should represent this distinction explicitly. 
+ +flowchart LR + A[GGUF Q8 / Q4_K] --> B[Quantized tensor format] + B --> C[Has packing, block metadata, scales] + + D[Native int8] --> E[Backend execution format] + E --> F[Has backend-specific layout and quant contract] + + C -. not equivalent .- F +Logical Shape vs Physical Storage +Logical shape must be part of the tensor contract. + +Physical storage should not define graph-visible shape. + +For example, a packed quantized tensor may occupy a one-dimensional byte segment internally, but the graph should see the tensor as its logical multidimensional shape. + +flowchart TD + A[Quantized tensor] --> B[Logical shape] + A --> C[Physical storage] + + B --> D[Graph contract] + C --> E[Implementation detail] + + D --> F[Shape inference] + D --> G[Op validation] + E --> H[Kernel-specific decoding] +The engine should avoid designs where a tensor appears as a 1D byte array at load time and is later patched into a logical 2D shape. The loader should produce properly shaped logical tensors directly. + +Load and Compile Pipeline +The recommended pipeline is: + +flowchart TD + A[Read model file] --> B[Create tensors with source dtype and logical shape] + B --> C[Build graph from DSL] + C --> D[Attach optional dtype constraints] + D --> E[Check backend capabilities] + E --> F{All constraints satisfied?} + + F -- Already satisfied --> G[Use tensors as-is] + F -- Convertible --> H[Lower tensors] + F -- Not satisfiable --> I[Fail before execution] + + G --> J[Resolved executable graph] + H --> J + J --> K[Forward execution] +Constraint resolution should happen before execution. Forward execution should not need to discover that a tensor cannot run on the selected backend. + +Backend Behavior +CPU / SIMD Backend +A general CPU backend should prefer adaptive execution. + +It can keep GGUF source dtypes when suitable kernels exist. 
+ +flowchart TD + A[CPU / SIMD backend] --> B[Accept multiple dtypes] + B --> C[Keep source dtype where possible] + C --> D[Dispatch on resolved tensor dtype] +Restricted NPU Backend +A restricted backend should declare supported executable dtypes and layouts. + +If required tensors are not already in that form, they must be lowered before execution. + +flowchart TD + A[NPU backend] --> B[Supports limited executable formats] + B --> C[Require native dtype/layout] + C --> D[Lower tensors before execution] + D --> E[Dispatch to NPU kernel] +Compiled Execution Path +The adaptive tensor-engine path should remain the default execution model for flexible GGUF loading and heterogeneous quantization. + +However, the system may also expose a separate compiled execution path. In this context, the DSL itself is not the final artifact. The DSL is converted into a resolved DAG, and the resolved DAG may then be converted into an executable plan or lowered further into StableHLO, MLIR, or native optimized code. 
+ +The recommended model is: + +DSL source +→ resolved DAG +→ executable plan +→ optional backend lowering +flowchart LR + A[DSL definition] --> B[Graph builder] + B --> C[Resolved DAG] + C --> D[Execution planner] + D --> E[Executable plan] + E --> F[Runtime / backend execution] +The resolved DAG should contain: + +op nodes +tensor edges +logical shapes +source dtypes +resolved execution dtypes +layouts +backend assignments +quantization metadata +explicit conversion or lowering nodes +dtype constraints and validation results +Example before dtype/backend resolution: + +flowchart TD + A[input: F16] --> B[rms_norm] + B --> C[linear_project] + W1[weight: Q4_K] --> C + C --> D[activation] + D --> E[linear_project] + W2[weight: Q8_0] --> E + E --> F[output: F16] +Example after resolution for a backend requiring native int8 weights: + +flowchart TD + A[input: F16] --> B[rms_norm] + B --> C[linear_project] + + W1[weight: Q4_K] --> L1[lower Q4_K to int8] + L1 --> C + + C --> D[activation] + D --> E[linear_project] + + W2[weight: Q8_0] --> L2[lower Q8_0 to int8] + L2 --> E + + E --> F[output: F16] +The compiled path has a different purpose from the adaptive runtime path. 
+ +The adaptive path is optimized for flexibility: + +load many GGUF variants with the same DSL +preserve source dtypes where possible +dispatch dynamically based on resolved tensor dtype +support mixed quantization without requiring a separate model definition +The compiled path is optimized for stable, specialized execution: + +freeze dtype and layout decisions before execution +lower the graph into a resolved DAG +schedule the DAG into an executable plan +optionally lower the plan into StableHLO, MLIR, or native optimized code +allow aggressive fusion, layout planning, memory planning, and static validation +flowchart TD + A[DSL graph] --> B[Constraint resolution] + B --> C[Resolved DAG with dtypes and shapes] + C --> D[Executable plan] + D --> E{Optional external compiler path} + E -- No --> F[Tensor-engine execution] + E -- Yes --> G[StableHLO / MLIR] + G --> H[Backend optimization] + H --> I[Native optimized artifact] +The compiled execution path may produce several levels of artifact: + +Resolved DAG: normalized graph with explicit tensor flow and dtype/layout metadata. +Executable tensor-engine plan: scheduled graph with selected kernels and planned buffers. +StableHLO / MLIR module: compiler IR for external optimization and backend lowering. +Native backend artifact: JIT function, shared library, command buffer, serialized runtime module, or backend-specific executable blob. +This means the system can support two complementary modes: + +flowchart LR + A[DSL + model file] --> B{Execution mode} + + B -- Adaptive runtime --> C[Tensor engine] + C --> D[Dynamic dtype/backend dispatch] + + B -- Compiled path --> E[Resolved DAG] + E --> F[Executable plan] + F --> G[Optional StableHLO / MLIR / native code] +The compiled path should be treated as an explicit lowering target, not as the default interpretation of the DSL. 
This keeps the normal GGUF path flexible while still allowing high-performance deployment when dtype, shape, layout, and backend contracts are stable enough to compile. + +Lowering Phase Placement +Lowering should belong primarily to load/compile-time graph preparation, not ordinary forward execution. + +The recommended split is: + +Loading: + read the file and create logical tensors + +Compilation / graph preparation: + resolve constraints, insert conversions, select layouts, select kernels + +Lowering: + convert resolved graph/tensors/ops into the representation required by the selected backend + +Execution: + run the already-lowered executable plan +flowchart TD + A[Load model file] --> B[Create logical tensors] + B --> C[Build DSL DAG] + C --> D[Resolve dtype/backend constraints] + D --> E[Lower tensors / ops / layouts] + E --> F[Create executable plan] + F --> G[Forward execution] +Lowering may happen at load time when the target backend is already known. + +load GGUF for NPU +→ immediately convert required tensors to int8/native layout +→ store lowered tensors +This provides early failure and simple execution, but is less flexible if the same loaded model should target multiple backends. + +Lowering may also happen during an explicit compile step. + +load GGUF once +→ keep source tensors +→ compile for CPU or NPU later +→ lower only for the selected target +This is more flexible and allows multiple lowered variants to be cached. + +Execution-time lowering should be avoided for hard requirements. If it is used, it should be treated as lazy or deferred compilation, not as normal forward execution. It must produce the same result as the explicit compile path and should cache the lowered result for subsequent executions. + +DType Safety +This design provides dtype safety by turning dtype compatibility into a graph preparation invariant. 
+ +The prepared DAG or executable plan should contain only tensors, conversions, and ops whose dtype and layout contracts have been resolved and validated against the selected backend. + +flowchart LR + A[Source tensors] --> B[Build DAG] + B --> C[Resolve dtype constraints] + C --> D[Insert conversions/lowering] + D --> E[Validate backend kernels] + E --> F[DType-safe executable plan] + F --> G[Forward execution] +After graph preparation, forward execution should not perform dtype discovery. It should execute a plan that is already known to be valid. + +The safety guarantees are: + +every tensor has explicit dtype metadata +every op declares accepted input and output dtype contracts +every backend declares supported dtype/layout/kernel combinations +constraint resolution validates the graph before execution +required conversions are inserted explicitly +unsupported dtype combinations fail before forward execution +kernel dispatch uses resolved dtype, not wrapper-class identity +A valid executable node must satisfy all relevant contracts: + +flowchart TD + A[Node: linearProject] --> B{Input dtype valid?} + A --> C{Weight dtype valid?} + A --> D{Output dtype valid?} + A --> E{Backend kernel exists?} + + B --> F[Valid executable node] + C --> F + D --> F + E --> F +DType safety does not automatically imply precision safety. + +A conversion such as Q4_K to int8, FP16 to int8, or Q8 to int8 may be valid according to dtype rules while still being lossy. Lossy conversion, calibration, scale handling, and acceptable accuracy loss should be controlled by separate conversion and precision policies. + +The complete safety model includes: + +dtype safety: can the graph execute with these dtypes? +layout safety: does the backend understand this memory layout? +shape safety: do tensor dimensions match op contracts? +conversion safety: is this conversion allowed, calibrated, cached, and valid? +precision safety: is the accuracy loss acceptable? 
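The dtype-safety part of this model can be sketched as a resolution step that runs before forward execution. Everything below is hypothetical: `PolicySketch` only mirrors the four policy categories from this RFC, dtypes are plain strings, and `canConvert` stands in for whatever conversion registry an implementation might have.

```kotlin
// Hypothetical policy type mirroring the RFC's four categories; not a prescribed API.
sealed interface PolicySketch {
    object AnyDType : PolicySketch
    data class Require(val dtype: String) : PolicySketch
    data class Prefer(val dtype: String) : PolicySketch
    data class OneOf(val dtypes: Set<String>) : PolicySketch
}

class ConstraintSketchException(message: String) : IllegalStateException(message)

// Returns the dtype the tensor should have at execution time, or throws before execution.
fun resolve(
    source: String,
    policy: PolicySketch,
    backendSupports: Set<String>,
    canConvert: (from: String, to: String) -> Boolean,
): String {
    // A target is reachable if the backend runs it and we either have it or can convert to it.
    fun reachable(target: String): Boolean =
        target in backendSupports && (target == source || canConvert(source, target))

    val keepSource = source.takeIf { it in backendSupports }
    val resolved = when (policy) {
        PolicySketch.AnyDType -> keepSource
        is PolicySketch.Require -> policy.dtype.takeIf { reachable(it) }
        // Soft requirement: fall back to the source dtype if the preference is unreachable.
        is PolicySketch.Prefer -> policy.dtype.takeIf { reachable(it) } ?: keepSource
        // Keep the source dtype if allowed, otherwise take the first reachable alternative.
        is PolicySketch.OneOf ->
            keepSource?.takeIf { it in policy.dtypes } ?: policy.dtypes.firstOrNull { reachable(it) }
    }
    return resolved ?: throw ConstraintSketchException(
        "cannot satisfy $policy for source dtype $source on a backend supporting $backendSupports"
    )
}
```

This sketch covers only dtype safety. Whether a reachable conversion is acceptably precise is a separate policy question, as noted above.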
+Error Handling +Errors should occur as early as possible. + +Load/compile-time errors +These should occur when: + +a hard dtype constraint cannot be satisfied +no conversion path exists +the selected backend does not support the required dtype +required layout lowering is unavailable +logical tensor shape is incompatible with the target kernel +quantization metadata is insufficient for conversion +Forward-time errors +Forward-time errors should be limited to unexpected execution failures. + +They should not be used for ordinary dtype compatibility discovery. + +Forward execution should only operate on resolved tensors. + +Kernel Dispatch +Kernel dispatch should use explicit tensor metadata. + +Dispatch should be based on: + +operation kind +input dtype +weight dtype +output dtype +backend +layout +possibly quantization parameters +Dispatch should not depend on: + +marker classes +wrapper class identity +raw storage array type +physical byte-count shape +flowchart LR + A[Resolved tensor metadata] --> B[Kernel key] + B --> C[Dispatch table] + C --> D[Selected backend kernel] +Benefits +This design provides: + +one DSL definition for many quantized variants +clean support for mixed-precision GGUF +explicit dtype semantics +early failure for impossible backend constraints +backend-specific lowering without polluting the architecture DSL +cleaner shape inference +no dtype marker-class hacks +less ambiguity between packed quantized formats and native execution formats +a natural path for CPU, SIMD, and NPU backends +Tradeoffs +More complex constraint resolution +The loader or compiler must understand dtype policies and backend capabilities. + +More explicit dtype model +The tensor engine must represent dtype and layout as first-class metadata. + +Conversion cost +When strict constraints require conversion, load or compile time may increase. + +For example, Q4_K to native int8 may require dequantization and requantization. 
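A rough sketch of the dequantize-then-requantize step shows where the cost and the precision loss come from. The block format here is deliberately simplified (one float scale per block of int8 codes) and is not the real Q4_K layout; the symmetric per-tensor int8 target is likewise an assumption, since real backends define their own quantization contract.

```kotlin
import kotlin.math.abs
import kotlin.math.roundToInt

// Simplified block format for illustration only: one float scale per block of int8 codes.
// The real Q4_K layout is more involved and is not modelled here.
class BlockSketch(val scale: Float, val codes: ByteArray)

// Step 1: expand every block to floats (value = code * block scale).
fun dequantize(blocks: List<BlockSketch>): FloatArray =
    blocks.flatMap { block -> block.codes.map { code -> code * block.scale } }.toFloatArray()

// Step 2: requantize with a single symmetric per-tensor scale (assumed target contract).
fun requantizeInt8(values: FloatArray): Pair<Float, ByteArray> {
    val maxAbs = values.maxOfOrNull { abs(it) } ?: 0f
    val scale = if (maxAbs == 0f) 1f else maxAbs / 127f
    val codes = ByteArray(values.size) { i ->
        (values[i] / scale).roundToInt().coerceIn(-127, 127).toByte()
    }
    return scale to codes
}
```

When one block holds values as large as 127 and another holds 0.25 and 0.75, the shared scale becomes 1.0 and the small values round to 0 and 1. The per-block scales carried information that the per-tensor scale cannot represent.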
+ +Potential precision loss +Some conversions are lossy. + +The system may need policy controls for whether lossy conversion is allowed. + +More backend capability metadata +Backends need to declare which dtypes, layouts, and conversions they support. + +Open Questions +Should dtype constraints live in the DSL, backend profile, or both? +Should lossy conversion require an explicit opt-in policy? +Should lowering happen at load time, compile time, or lazily before first execution? +Should lowered tensors be cached? +How should per-layer and per-tensor constraints be represented? +How should backend-specific layouts be named and versioned? +How should quantization metadata be preserved during lowering? +Should unsupported soft preferences warn or silently fall back? +Should graph optimization occur before or after dtype lowering? +How should mixed backend execution be represented? +Recommended Direction +The recommended direction is: + +adaptive by default +explicit dtype constraints when needed +DSL converted into a resolved DAG for prepared execution +constraints resolved before execution +conversion/lowering handled during load or compile preparation +execution-time lowering only as lazy/deferred compilation +hard requirements fail early +kernel dispatch uses real tensor dtype +logical shape belongs to the tensor contract +physical storage remains an implementation detail +compiled path may emit tensor-engine plans, StableHLO, MLIR, or native artifacts +In short: + +The DSL should define architecture first, while allowing explicit dtype requirements only where execution needs them. + +This gives the flexibility needed for GGUF and mixed quantization, while still supporting strict execution environments such as NPUs and explicit compiled targets such as StableHLO, MLIR, or native optimized code. + +Final Summary +A strict dtype DSL is clean for fixed execution environments, but too rigid for general GGUF loading. 
+ +A fully adaptive DSL fits GGUF better, but needs explicit dtype metadata and a principled way to handle strict backend requirements. + +The proposed hybrid model keeps the DSL adaptive by default and adds dtype constraints as execution policies. This allows source tensors to follow the model file unless an op or backend requires otherwise. When constraints exist, the loader or compiler either lowers the tensor into the required dtype/layout or fails before execution. + +The compiled execution path should be understood as DSL-to-DAG preparation. The DSL is converted into a resolved DAG with explicit tensor flow, dtype/layout metadata, backend assignments, and conversion/lowering nodes. That DAG may then become an executable tensor-engine plan or be lowered further into StableHLO, MLIR, or native backend artifacts. + +This gives the system dtype safety by making dtype compatibility a graph preparation invariant. Forward execution consumes a resolved plan rather than discovering dtype compatibility dynamically. + +This avoids confusing source storage with executable representation and provides a cleaner foundation for CPU, SIMD, NPU, and compiled StableHLO/MLIR/native execution targets. diff --git a/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/KernelProvider.kt b/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/KernelProvider.kt index c510d5fc..a5934221 100644 --- a/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/KernelProvider.kt +++ b/skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/KernelProvider.kt @@ -66,4 +66,41 @@ public interface KernelProvider { * this provider does not specialize Q8_0. Same fall-through pattern. */ public fun matmulQ8_0(): Q8_0MatmulKernel? = null + + /** + * Capability query: does this provider carry a kernel for + * [opName] with the given [dtypeKeys]? 
+ * + * Returns `true` iff the corresponding per-kernel accessor on + * this interface returns non-null. Callers (constraint + * resolution, fail-fast dispatch, capability tables) use this + * to ask "do you support this combination?" without actually + * fetching the kernel. + * + * Convention: + * - For matmul, [dtypeKeys] is `[inputDtypeName, weightDtypeName]` + * using the same string names as [sk.ainet.lang.types.DType.name] + * for floats / ints (`"Float32"`, `"BFloat16"`, …) and the + * short canonical block-format names for quantized weights + * (`"Q4_K"`, `"Q8_0"`). + * - For ops that aren't matmul (future: SDPA, gather, RMSNorm…), + * providers can override this method to declare those kernels. + * + * The default body covers the four matmul accessors that exist + * on this interface today. Providers that ship additional + * kernels override and chain through `super.supports(...)` for + * the matmul base cases. + */ + public fun supports(opName: String, dtypeKeys: List<String>): Boolean { + if (opName != "matmul" || dtypeKeys.size != 2) return false + val (input, weight) = dtypeKeys + if (input != "Float32") return false + return when (weight) { + "Float32" -> matmulFp32() != null + "BFloat16" -> matmulBf16() != null + "Q4_K" -> matmulQ4K() != null + "Q8_0" -> matmulQ8_0() != null + else -> false + } + } } diff --git a/skainet-backends/skainet-backend-api/src/jvmMain/kotlin/sk/ainet/backend/api/kernel/KernelStrictness.kt b/skainet-backends/skainet-backend-api/src/jvmMain/kotlin/sk/ainet/backend/api/kernel/KernelStrictness.kt new file mode 100644 index 00000000..ad472048 --- /dev/null +++ b/skainet-backends/skainet-backend-api/src/jvmMain/kotlin/sk/ainet/backend/api/kernel/KernelStrictness.kt @@ -0,0 +1,58 @@ +package sk.ainet.backend.api.kernel + +/** + * Process-global fail-fast policy for kernel resolution.
+ * + * The RFC's "fail before execution" principle says that if a graph + * can't find a registered kernel for an op/dtype combination, the + * problem should surface at the load / compile / dispatch boundary + * rather than as a silent perf regression at forward time (e.g. a + * Q4_K matmul that falls back to a scalar dequant + FP32 matmul). + * + * This object provides the smallest possible affordance: a system + * property (`-Dskainet.strict.kernels=true`) flips the runtime into + * fail-fast mode. Dispatch sites call [failIfStrict] right before + * they would otherwise silently fall back; the call is a no-op when + * strict mode is off, preserving the existing adaptive behaviour. + * + * Per-context strict mode (e.g. `DirectCpuExecutionContext.create(strict = true)`) + * is a follow-up that requires plumbing the flag through every + * platform's `platformDefaultCpuOpsFactory`. The system-property + * affordance is sufficient for tests, CI, and debugging — the + * primary use cases — and ships with zero cross-platform plumbing. + */ +public object KernelStrictness { + + /** System-property name that flips fail-fast mode on. */ + public const val SYSTEM_PROPERTY: String = "skainet.strict.kernels" + + /** + * Returns `true` when the runtime should fail fast instead of + * falling back. Reads the system property on every call so a + * test can toggle it via `System.setProperty(...)` between + * cases without restarting the JVM. + */ + public fun isEnabled(): Boolean = + System.getProperty(SYSTEM_PROPERTY) == "true" + + /** + * Throws [NoSuchKernelException] with the supplied message + * builder when strict mode is on; otherwise returns. The + * message builder is only invoked when the exception is going + * to be thrown, so callers can include expensive details + * (provider list, dtype tuples) without paying the cost in the + * default-adaptive path. 
+ */ + public inline fun failIfStrict(message: () -> String) { + if (isEnabled()) throw NoSuchKernelException(message()) + } +} + +/** + * Raised by dispatch sites when [KernelStrictness] is enabled and + * no registered kernel matches the requested op/dtype combination. + * The message includes the op name, the failing dtype tuple, and + * the list of currently-registered providers so the operator can + * see exactly which capability is missing. + */ +public class NoSuchKernelException(message: String) : RuntimeException(message) diff --git a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel.kt b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel.kt index 7a31e708..f3de43d2 100644 --- a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel.kt +++ b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel.kt @@ -70,13 +70,37 @@ public object PanamaVectorQ8_0MatmulKernel : Q8_0MatmulKernel { var blockAccVec = FloatVector.zero(floatSpecies) var k = 0 - while (k < BLOCK_SIZE) { - val byteVec = ByteVector.fromArray(byteSpeciesForFloat, weight, codesBase + k) - @Suppress("UNCHECKED_CAST") - val codesVec = byteVec.castShape(floatSpecies, 0) as FloatVector - val inputVec = FloatVector.fromArray(floatSpecies, input, inputBase + k) - blockAccVec = inputVec.fma(codesVec, blockAccVec) - k += laneCount + if (laneCount == 4) { + // NEON (4-wide float species): each ByteVector load brings + // 8 bytes (SPECIES_64 is the smallest byte species). To + // consume all 8 — and to avoid reading 4 bytes past the + // codes region on the last iteration of a block — convert + // both halves via `castShape(species, part)` per load and + // step k by 8. 
+ while (k < BLOCK_SIZE) { + val byteVec = ByteVector.fromArray(byteSpeciesForFloat, weight, codesBase + k) + @Suppress("UNCHECKED_CAST") + val codesLo = byteVec.castShape(floatSpecies, 0) as FloatVector + @Suppress("UNCHECKED_CAST") + val codesHi = byteVec.castShape(floatSpecies, 1) as FloatVector + val inLo = FloatVector.fromArray(floatSpecies, input, inputBase + k) + val inHi = FloatVector.fromArray(floatSpecies, input, inputBase + k + 4) + blockAccVec = inLo.fma(codesLo, blockAccVec) + blockAccVec = inHi.fma(codesHi, blockAccVec) + k += 8 + } + } else { + // AVX2 (8-wide): the SPECIES_64 load and the + // floatSpecies cast width match — one FMA per + // iteration, step k by `laneCount`. + while (k < BLOCK_SIZE) { + val byteVec = ByteVector.fromArray(byteSpeciesForFloat, weight, codesBase + k) + @Suppress("UNCHECKED_CAST") + val codesVec = byteVec.castShape(floatSpecies, 0) as FloatVector + val inputVec = FloatVector.fromArray(floatSpecies, input, inputBase + k) + blockAccVec = inputVec.fma(codesVec, blockAccVec) + k += laneCount + } } acc += blockAccVec.reduceLanes(VectorOperators.ADD) * d } diff --git a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt index e9be28bd..703beebf 100644 --- a/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt +++ b/skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt @@ -7,6 +7,7 @@ import sk.ainet.backend.api.kernel.Bf16MatmulKernel import sk.ainet.backend.api.kernel.Fp32MatmulKernel import sk.ainet.backend.api.kernel.KernelRegistry import sk.ainet.backend.api.kernel.KernelServiceLoader +import sk.ainet.backend.api.kernel.KernelStrictness import sk.ainet.backend.api.kernel.Q4KMatmulKernel import sk.ainet.backend.api.kernel.Q8_0MatmulKernel import 
sk.ainet.exec.kernel.ScalarBf16MatmulKernel @@ -137,6 +138,20 @@ internal class DefaultCpuOpsJvm( chooseQuantizedMatmul(a, b)?.let { return it } // Fallback to standard FP32 matmul chooseMatmul(a, b)?.let { return it } + // RFC fail-fast point: if `-Dskainet.strict.kernels=true`, surface + // the missing kernel here rather than letting `super.matmul` quietly + // pick the scalar dequant + FP32 fallback. The strictness check is + // a no-op when the property is unset, preserving the existing + // adaptive behaviour. + KernelStrictness.failIfStrict { + val inDt = a.dtype.simpleName ?: a.dtype.toString() + val wDt = b.dtype.simpleName ?: b.dtype.toString() + val providers = KernelRegistry.providers().joinToString { p -> + "${p.name}(priority=${p.priority}, available=${p.isAvailable()})" + }.ifEmpty { "" } + "matmul ($inDt × $wDt) has no SPI kernel; would silently fall back " + + "to super.matmul. Registered providers: $providers" + } return super.matmul(a, b) } diff --git a/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/KernelProviderSupportsTest.kt b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/KernelProviderSupportsTest.kt new file mode 100644 index 00000000..cc68683b --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/KernelProviderSupportsTest.kt @@ -0,0 +1,98 @@ +package sk.ainet.exec.kernel + +import kotlin.test.AfterTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import sk.ainet.backend.api.kernel.KernelRegistry + +/** + * Capability-query parity tests for the default + * `KernelProvider.supports(...)` implementation. 
The default body + * introspects the existing per-kernel accessors and reports + * `true` iff the accessor returns non-null; these tests confirm + * the introspection matches what `bestAvailable()?.matmul*()` + * actually returns for the two providers shipped today. + */ +class KernelProviderSupportsTest { + + @BeforeTest + fun setUp() = KernelRegistry.clearForTesting() + + @AfterTest + fun tearDown() = KernelRegistry.clearForTesting() + + @Test + fun panama_supports_matches_accessor_nullability() { + val p = PanamaVectorKernelProvider + assertEquals( + p.matmulFp32() != null, + p.supports("matmul", listOf("Float32", "Float32")), + "FP32 matmul support must mirror matmulFp32() != null", + ) + assertEquals( + p.matmulBf16() != null, + p.supports("matmul", listOf("Float32", "BFloat16")), + "BF16 matmul support must mirror matmulBf16() != null", + ) + assertEquals( + p.matmulQ4K() != null, + p.supports("matmul", listOf("Float32", "Q4_K")), + "Q4_K matmul support must mirror matmulQ4K() != null", + ) + assertEquals( + p.matmulQ8_0() != null, + p.supports("matmul", listOf("Float32", "Q8_0")), + "Q8_0 matmul support must mirror matmulQ8_0() != null", + ) + } + + @Test + fun scalar_supports_matches_accessor_nullability() { + val p = ScalarKernelProvider + assertEquals( + p.matmulFp32() != null, + p.supports("matmul", listOf("Float32", "Float32")), + ) + // Scalar declines quantized matmuls today; the capability query + // must agree. 
+ assertEquals( + p.matmulQ4K() != null, + p.supports("matmul", listOf("Float32", "Q4_K")), + ) + } + + @Test + fun unknown_op_returns_false() { + assertFalse( + PanamaVectorKernelProvider.supports("sdpa", listOf("Float32", "Float32", "Float32")), + "supports() must return false for ops the provider does not advertise", + ) + } + + @Test + fun matmul_with_wrong_arity_returns_false() { + assertFalse( + PanamaVectorKernelProvider.supports("matmul", listOf("Float32")), + "matmul takes exactly two dtype keys (input, weight)", + ) + assertFalse( + PanamaVectorKernelProvider.supports("matmul", listOf("Float32", "Float32", "Float32")), + ) + } + + @Test + fun matmul_with_non_float32_input_returns_false() { + // The kernel SPI today only specializes FP32-input matmuls. + // BFloat16-input matmul is a future kernel; the capability + // query must say "no" until that kernel exists. + assertFalse( + PanamaVectorKernelProvider.supports("matmul", listOf("BFloat16", "Float32")), + ) + assertTrue( + PanamaVectorKernelProvider.supports("matmul", listOf("Float32", "Float32")), + ) + } +} diff --git a/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/KernelStrictnessTest.kt b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/KernelStrictnessTest.kt new file mode 100644 index 00000000..fd4f9bcf --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/KernelStrictnessTest.kt @@ -0,0 +1,97 @@ +package sk.ainet.exec.kernel + +import kotlin.test.AfterTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue +import sk.ainet.backend.api.kernel.KernelStrictness +import sk.ainet.backend.api.kernel.NoSuchKernelException +import sk.ainet.context.DirectCpuExecutionContext +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.dsl.tensor +import sk.ainet.lang.types.FP32 + +/** + * Tests the 
[KernelStrictness] fail-fast affordance and verifies + * that the happy path (FP32 × FP32 matmul, which always finds an + * SPI kernel) does NOT throw when strict mode is enabled — the + * exception is only raised when dispatch genuinely can't resolve a + * kernel. + */ +class KernelStrictnessTest { + + private var savedProperty: String? = null + + @BeforeTest + fun saveProperty() { + savedProperty = System.getProperty(KernelStrictness.SYSTEM_PROPERTY) + } + + @AfterTest + fun restoreProperty() { + val saved = savedProperty + if (saved != null) { + System.setProperty(KernelStrictness.SYSTEM_PROPERTY, saved) + } else { + System.clearProperty(KernelStrictness.SYSTEM_PROPERTY) + } + } + + @Test + fun isEnabled_reads_system_property() { + System.clearProperty(KernelStrictness.SYSTEM_PROPERTY) + assertEquals(false, KernelStrictness.isEnabled(), "default must be off") + + System.setProperty(KernelStrictness.SYSTEM_PROPERTY, "true") + assertEquals(true, KernelStrictness.isEnabled(), "property = 'true' enables strict mode") + + System.setProperty(KernelStrictness.SYSTEM_PROPERTY, "false") + assertEquals(false, KernelStrictness.isEnabled(), "property = 'false' disables strict mode") + + System.setProperty(KernelStrictness.SYSTEM_PROPERTY, "yes") + assertEquals(false, KernelStrictness.isEnabled(), "only 'true' enables; other values are off") + } + + @Test + fun failIfStrict_is_noop_when_disabled() { + System.clearProperty(KernelStrictness.SYSTEM_PROPERTY) + var called = false + KernelStrictness.failIfStrict { + called = true + "should never be evaluated" + } + assertEquals(false, called, "message lambda must not run when strict is off") + } + + @Test + fun failIfStrict_throws_when_enabled() { + System.setProperty(KernelStrictness.SYSTEM_PROPERTY, "true") + val ex = assertFailsWith<NoSuchKernelException> { + KernelStrictness.failIfStrict { "matmul (FP32 × Q4_K) has no SPI kernel" } + } + assertTrue( + ex.message?.contains("FP32 × Q4_K") == true, + "exception message must come from the lambda:
'${ex.message}'", + ) + } + + @Test + fun fp32_matmul_does_not_throw_under_strict_mode() { + // The FP32 path always resolves via fp32MatmulKernel (falls back + // to ScalarMatmulKernel if registry is empty). Strict mode must + // NOT break the happy path. + System.setProperty(KernelStrictness.SYSTEM_PROPERTY, "true") + val ctx = DirectCpuExecutionContext.create() + val a = tensor(ctx, FP32::class) { + tensor { shape(2, 3) { from(1f, 2f, 3f, 4f, 5f, 6f) } } + } + val b = tensor(ctx, FP32::class) { + tensor { shape(3, 2) { from(1f, 2f, 3f, 4f, 5f, 6f) } } + } + // No throw expected — FP32 matmul has a resolved kernel. + val c = ctx.ops.matmul(a, b) + assertEquals(Shape(2, 2), c.data.shape, "happy-path FP32 matmul must work under strict mode") + } +} diff --git a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/ResolvedComputeGraph.kt b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/ResolvedComputeGraph.kt new file mode 100644 index 00000000..826c5b75 --- /dev/null +++ b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/ResolvedComputeGraph.kt @@ -0,0 +1,126 @@ +package sk.ainet.lang.graph + +import sk.ainet.lang.types.BF16 +import sk.ainet.lang.types.DType +import sk.ainet.lang.types.FP16 +import sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.FP64 +import sk.ainet.lang.types.Int16 +import sk.ainet.lang.types.Int32 +import sk.ainet.lang.types.Int64 +import sk.ainet.lang.types.Int8 + +/** + * Typed view over a [ComputeGraph] that exposes resolved dtype, + * layout, and backend metadata on edges and nodes. Sketches the + * RFC's "resolved DAG" concept (`rfc.md`, "Resolved DAG" section) + * without introducing a parallel IR — every accessor is a typed + * decode of metadata that already lives on the wrapped graph. + * + * Construction contract (from the `validate()` method): + * - Every edge's [GraphEdge.tensorSpec.dtype] string decodes to a + * known [DType]. 
Acts as the load/compile-time precondition the + * RFC calls out — forward execution can rely on this being true. + * - Every node carries `metadata["dtype_resolved"] == true`, proof + * that `DTypeConstraintResolutionPass` (W7) has walked the graph. + * + * `resolvedLayout` and `backendAssignment` are placeholders today + * (always `null`). The names exist so the HLO converter overload + * in W9 has stable hooks to read; future passes will populate them + * as layout planning and backend selection land. + */ +public class ResolvedComputeGraph(public val delegate: ComputeGraph) { + + /** All wrapped nodes. */ + public val nodes: List<GraphNode> get() = delegate.nodes + + /** All wrapped edges. */ + public val edges: List<GraphEdge> get() = delegate.edges + + /** + * Resolved logical dtype for the edge identified by [edgeId], or + * `null` if the edge is unknown or carries a dtype string that + * doesn't decode to a registered [DType]. + */ + public fun resolvedDtype(edgeId: String): DType? { + val edge = edges.firstOrNull { it.id == edgeId } ?: return null + return parseDtype(edge.tensorSpec.dtype) + } + + /** + * Placeholder for the resolved memory layout. Returns `null` + * today — populated by future layout-planning passes. + */ + public fun resolvedLayout(edgeId: String): Layout? = null + + /** + * Placeholder for the backend assignment. Returns `null` today + * — populated by future multi-backend scheduling. + */ + public fun backendAssignment(nodeId: String): String? = null + + /** + * Precondition check for the resolved-DAG contract: + * - every edge has a parseable dtype + * - every node carries the `dtype_resolved` marker from + * [sk.ainet.compile.opt.passes.DTypeConstraintResolutionPass] + * + * Returns a [ResolvedGraphValidation] result rather than + * throwing, so callers can choose between hard-fail (W9 HLO + * converter) and soft-warn (debugging tooling).
+ */ + public fun validate(): ResolvedGraphValidation { + val errors = mutableListOf<String>() + for (edge in edges) { + if (parseDtype(edge.tensorSpec.dtype) == null) { + errors += "edge '${edge.id}' has unparseable dtype '${edge.tensorSpec.dtype}'" + } + } + for (node in nodes) { + if (node.metadata["dtype_resolved"] != true) { + errors += "node '${node.id}' is missing dtype_resolved marker " + + "— run DTypeConstraintResolutionPass before wrapping in ResolvedComputeGraph" + } + } + return ResolvedGraphValidation(valid = errors.isEmpty(), errors = errors) + } + + /** + * Mirror of the alias-aware lookup used by the + * constraint-resolution pass. Kept here as a self-contained + * piece so this module doesn't pull in `skainet-compile-opt`. + */ + private fun parseDtype(dtypeStr: String): DType? = when (dtypeStr) { + "Float32", "FP32", "F32", "float32" -> FP32 + "Float16", "FP16", "F16", "float16" -> FP16 + "BFloat16", "BF16", "bf16" -> BF16 + "Float64", "FP64", "F64", "float64" -> FP64 + "Int8", "I8", "int8" -> Int8 + "Int16", "I16", "int16" -> Int16 + "Int32", "I32", "int32" -> Int32 + "Int64", "I64", "int64" -> Int64 + else -> null + } +} + +/** + * Placeholder for resolved memory-layout metadata. Concrete + * implementations (row-major, col-major, packed-block, native + * NPU layouts) come with the layout-planning pass. + */ +public interface Layout + +/** + * Result of [ResolvedComputeGraph.validate]. + */ +public data class ResolvedGraphValidation( + public val valid: Boolean, + public val errors: List<String>, +) { + /** Throws if invalid — used by callers that want hard-fail behaviour.
*/ + public fun requireValid() { + require(valid) { + "ResolvedComputeGraph validation failed: " + errors.joinToString("; ") + } + } +} diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/lang/graph/ResolvedComputeGraphTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/lang/graph/ResolvedComputeGraphTest.kt new file mode 100644 index 00000000..bf985f64 --- /dev/null +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/lang/graph/ResolvedComputeGraphTest.kt @@ -0,0 +1,98 @@ +package sk.ainet.lang.graph + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNull +import kotlin.test.assertTrue +import sk.ainet.lang.tensor.ops.GenericOperation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.types.BF16 +import sk.ainet.lang.types.FP32 + +class ResolvedComputeGraphTest { + + private fun makeGraph( + edgeDtype: String = "Float32", + nodeMarkedResolved: Boolean = true, + ): ComputeGraph { + val g = DefaultComputeGraph() + val meta = if (nodeMarkedResolved) mapOf("dtype_resolved" to true) else emptyMap() + val a = g.addNode(GraphNode( + id = "a", operation = GenericOperation("input"), + inputs = emptyList(), + outputs = listOf(TensorSpec("a-out", listOf(4), edgeDtype)), + metadata = meta, + )) + val b = g.addNode(GraphNode( + id = "b", operation = GenericOperation("noop"), + inputs = listOf(TensorSpec("a-out", listOf(4), edgeDtype)), + outputs = listOf(TensorSpec("b-out", listOf(4), edgeDtype)), + metadata = meta, + )) + g.addEdge(GraphEdge("e1", a, b, tensorSpec = TensorSpec("e1", listOf(4), edgeDtype))) + return g + } + + @Test + fun resolvedDtype_decodes_canonical_name() { + val g = ResolvedComputeGraph(makeGraph(edgeDtype = "Float32")) + assertEquals(FP32, g.resolvedDtype("e1")) + } + + @Test + fun resolvedDtype_decodes_short_alias() { + val g = 
ResolvedComputeGraph(makeGraph(edgeDtype = "BF16")) + assertEquals(BF16, g.resolvedDtype("e1")) + } + + @Test + fun resolvedDtype_returns_null_for_unknown_edge() { + val g = ResolvedComputeGraph(makeGraph()) + assertNull(g.resolvedDtype("does-not-exist")) + } + + @Test + fun layout_and_backend_are_placeholders() { + val g = ResolvedComputeGraph(makeGraph()) + assertNull(g.resolvedLayout("e1"), "layout placeholder must return null today") + assertNull(g.backendAssignment("a"), "backend placeholder must return null today") + } + + @Test + fun validate_passes_for_well_formed_graph() { + val g = ResolvedComputeGraph(makeGraph()) + val result = g.validate() + assertTrue(result.valid) + assertEquals(emptyList<String>(), result.errors) + } + + @Test + fun validate_fails_for_missing_resolved_marker() { + val g = ResolvedComputeGraph(makeGraph(nodeMarkedResolved = false)) + val result = g.validate() + assertFalse(result.valid) + assertTrue( + result.errors.any { it.contains("dtype_resolved") }, + "missing-marker error must mention the marker key: ${result.errors}", + ) + } + + @Test + fun validate_fails_for_unparseable_dtype() { + val g = ResolvedComputeGraph(makeGraph(edgeDtype = "imaginary")) + val result = g.validate() + assertFalse(result.valid) + assertTrue( + result.errors.any { it.contains("imaginary") }, + "errors must surface the unparseable dtype string: ${result.errors}", + ) + } + + @Test + fun requireValid_throws_on_invalid_graph() { + val g = ResolvedComputeGraph(makeGraph(nodeMarkedResolved = false)) + assertFailsWith<IllegalArgumentException> { g.validate().requireValid() } + } +} diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/dag2hlo.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/dag2hlo.kt index 406165e2..a8636e3b 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/dag2hlo.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/dag2hlo.kt
@@ -55,6 +55,40 @@ public fun toStableHlo(graph: ComputeGraph, functionName: String = "main"): Stab return converter.convert(graph, functionName) } +/** + * Export a [sk.ainet.lang.graph.ResolvedComputeGraph] into a StableHLO + * MLIR module — the dtype-resolved entry point that the W7 + * `DTypeConstraintResolutionPass` produces. + * + * The contract this overload upholds vs the plain [ComputeGraph] + * variant: every edge's dtype has already been resolved to a typed + * [sk.ainet.lang.types.DType] (the wrapper's `validate()` would + * have caught any unparseable strings), and every node carries the + * `dtype_resolved` marker from the pass. Callers that flow through + * this overload get a precondition guarantee that the HLO emit + * step won't silently misinterpret a stray dtype string. + * + * Today the converter still consumes the underlying [ComputeGraph] — + * the wrapper is the *contract*, not a separate emit path. As + * future passes start writing layout / backend metadata into the + * resolved graph, the converter can read those typed accessors + * directly. This entry point gives them the stable hook to do so. + */ +public fun toStableHlo( + graph: sk.ainet.lang.graph.ResolvedComputeGraph, + functionName: String = "main", + validate: Boolean = true, +): StableHloModule { + if (validate) { + graph.validate().requireValid() + } + // Delegate to the underlying ComputeGraph emit path — same HLO output + // for graphs that pass validation. Future versions can branch here to + // consume `graph.resolvedLayout(edgeId)` / `graph.backendAssignment(nodeId)` + // once those passes ship. + return toStableHlo(graph.delegate, functionName) +} + /** * Legacy implementation for backward compatibility. 
* diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ResolvedComputeGraphHloTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ResolvedComputeGraphHloTest.kt new file mode 100644 index 00000000..4a1b8406 --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/ResolvedComputeGraphHloTest.kt @@ -0,0 +1,94 @@ +package sk.ainet.compile.hlo + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.graph.ResolvedComputeGraph +import sk.ainet.lang.tensor.ops.GenericOperation +import sk.ainet.lang.tensor.ops.TensorSpec + +/** + * Round-trip and byte-equivalence tests for the + * `toStableHlo(ResolvedComputeGraph)` overload added in W9 of #615. + * + * Key property: when the underlying graph passes resolved-graph + * validation, the two HLO entry points must produce byte-identical + * output. The wrapper is the *contract*, not a separate emit path. 
+ */ +class ResolvedComputeGraphHloTest { + + private fun buildSimpleGraph(): DefaultComputeGraph { + val g = DefaultComputeGraph() + val resolvedMeta = mapOf("dtype_resolved" to true) + val input = g.addNode( + GraphNode( + id = "input", + operation = GenericOperation("input"), + inputs = emptyList(), + outputs = listOf(TensorSpec("input-out", listOf(2, 3), "FP32")), + metadata = resolvedMeta, + ), + ) + val relu = g.addNode( + GraphNode( + id = "relu", + operation = GenericOperation("relu"), + inputs = listOf(TensorSpec("input-out", listOf(2, 3), "FP32")), + outputs = listOf(TensorSpec("relu-out", listOf(2, 3), "FP32")), + metadata = resolvedMeta, + ), + ) + g.addEdge(GraphEdge("e1", input, relu, tensorSpec = TensorSpec("e1", listOf(2, 3), "FP32"))) + return g + } + + @Test + fun resolved_overload_produces_same_module_as_plain_overload() { + val g = buildSimpleGraph() + val viaPlain = toStableHlo(g) + val viaResolved = toStableHlo(ResolvedComputeGraph(g)) + // functionName + content are byte-identical. + assertEquals(viaPlain.functionName, viaResolved.functionName) + assertEquals(viaPlain.content, viaResolved.content) + } + + @Test + fun resolved_overload_validates_by_default() { + // Build a graph that's missing the dtype_resolved marker. + val g = DefaultComputeGraph() + val node = g.addNode( + GraphNode( + id = "input", + operation = GenericOperation("input"), + inputs = emptyList(), + outputs = listOf(TensorSpec("input-out", listOf(2), "FP32")), + metadata = emptyMap(), + ), + ) + // Default validate = true must reject this. + assertFailsWith<IllegalArgumentException> { + toStableHlo(ResolvedComputeGraph(g)) + } + } + + @Test + fun resolved_overload_skips_validation_when_opted_out() { + // Same unmarked graph, but validate = false.
+ val g = DefaultComputeGraph() + g.addNode( + GraphNode( + id = "input", + operation = GenericOperation("input"), + inputs = emptyList(), + outputs = listOf(TensorSpec("input-out", listOf(2), "FP32")), + metadata = emptyMap(), + ), + ) + // No throw expected. + val module = toStableHlo(ResolvedComputeGraph(g), validate = false) + assertEquals("main", module.functionName) + } +} diff --git a/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/GraphOptimizationPipeline.kt b/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/GraphOptimizationPipeline.kt index a59e2e8b..aa212d9a 100644 --- a/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/GraphOptimizationPipeline.kt +++ b/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/GraphOptimizationPipeline.kt @@ -2,6 +2,7 @@ package sk.ainet.compile.opt import sk.ainet.lang.graph.ComputeGraph import sk.ainet.compile.opt.passes.ConstantFoldingPass +import sk.ainet.compile.opt.passes.DTypeConstraintResolutionPass import sk.ainet.compile.opt.passes.DeadCodeEliminationPass import sk.ainet.compile.opt.passes.LLMFusionPass import sk.ainet.compile.opt.passes.OperationFusionPass @@ -73,6 +74,12 @@ public class GraphOptimizationPipeline( */ public fun createDefault(): GraphOptimizationPipeline = GraphOptimizationPipeline( passes = listOf( + // Resolve dtype constraints first so fusion / DCE / constant + // folding see the resolved-or-failed graph rather than a + // mix of policy-tagged and bare nodes. Per the RFC, this + // is the boundary where dtype problems surface — every + // later pass can assume dtype-validity. 
+ DTypeConstraintResolutionPass(), DeadCodeEliminationPass(), ConstantFoldingPass(), OperationFusionPass() @@ -84,6 +91,7 @@ public class GraphOptimizationPipeline( */ public fun createAggressive(): GraphOptimizationPipeline = GraphOptimizationPipeline( passes = listOf( + DTypeConstraintResolutionPass(), DeadCodeEliminationPass(), ConstantFoldingPass(), OperationFusionPass() @@ -95,6 +103,7 @@ public class GraphOptimizationPipeline( * Creates an LLM-optimized pipeline with transformer-specific passes. * * Pass ordering: + * 0. DTypeConstraintResolution — resolve dtype policies before fusion * 1. TransposeElimination — fold transposes into matmuls * 2. SharedWeightDedup — deduplicate tied weights (e.g. token_embd ↔ output) * 3. LLMFusion — fuse RMSNorm, SwiGLU, QKV patterns @@ -103,6 +112,7 @@ public class GraphOptimizationPipeline( */ public fun createLLM(): GraphOptimizationPipeline = GraphOptimizationPipeline( passes = listOf( + DTypeConstraintResolutionPass(), TransposeEliminationPass(), SharedWeightDeduplicationPass(), LLMFusionPass(), diff --git a/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/passes/DTypeConstraintResolutionPass.kt b/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/passes/DTypeConstraintResolutionPass.kt new file mode 100644 index 00000000..8002523f --- /dev/null +++ b/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/passes/DTypeConstraintResolutionPass.kt @@ -0,0 +1,153 @@ +package sk.ainet.compile.opt.passes + +import sk.ainet.compile.opt.GraphOptimizationPass +import sk.ainet.compile.opt.GraphOptimizationResult +import sk.ainet.lang.graph.ComputeGraph +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.types.BF16 +import sk.ainet.lang.types.DType +import sk.ainet.lang.types.DTypePolicy +import sk.ainet.lang.types.FP16 +import sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.FP64 +import sk.ainet.lang.types.Int8 +import 
sk.ainet.lang.types.Int16 +import sk.ainet.lang.types.Int32 +import sk.ainet.lang.types.Int64 + +/** + * Pass that enforces per-node [DTypePolicy] constraints attached to + * graph nodes (via the `dag { … dtypePolicy(…) }` DSL extension from + * W6 of #615). Implements the RFC's "fail before execution" rule — + * any [DTypePolicy.Require] that can't be satisfied raises + * [DtypeConstraintViolationException] *here*, at graph-prep time, + * not at forward execution. + * + * Policy semantics: + * - `Any`: no dtype check, but the node is still marked resolved. + * Nodes without an attached policy are skipped entirely. + * - `Require(target)`: every input edge to the node MUST already + * have dtype matching `target`. Mismatch throws + * [DtypeConstraintViolationException]. + * - `Prefer(target)`: input dtype matching `target` is preferred; + * mismatches emit a diagnostic but do not fail. + * - `OneOf(allowed)`: every input edge's dtype MUST already be in + * `allowed`. Mismatch throws. + * + * **Scope intentionally narrow.** This pass does not insert cast + * nodes today — when a `Require` mismatches, it fails fast (which + * is the RFC's prescribed behaviour when no cast kernel exists). + * Cast-node insertion is a follow-up that ships alongside concrete + * cast kernels (Q4_K → Int8, FP32 → BF16, …). See the + * out-of-scope section of issue #615. + * + * Side effect on the graph: visited nodes get + * `metadata["dtype_resolved"] = true` so downstream passes (and the + * future `ResolvedComputeGraph` wrapper from W8) can confirm the + * pass has run. + */ +public class DTypeConstraintResolutionPass : GraphOptimizationPass { + + override val name: String = "dtype-constraint-resolution" + + override fun apply(graph: ComputeGraph): GraphOptimizationResult { + val diagnostics = mutableListOf<String>() + var changed = false + + for (node in graph.nodes) { + val policy = node.metadata[POLICY_KEY] as?
DTypePolicy ?: continue + val inputDtypes = node.inputs.map { it.dtype } + + when (policy) { + DTypePolicy.Any -> { /* permissive; no-op */ } + + is DTypePolicy.Require -> { + val targetName = policy.target.name + for ((i, dtypeStr) in inputDtypes.withIndex()) { + if (!dtypeStringMatches(dtypeStr, policy.target)) { + throw DtypeConstraintViolationException( + "Node '${node.id}' (${node.operationName}) declares " + + "DTypePolicy.Require($targetName) but input $i has dtype '$dtypeStr'. " + + "Cast kernels are not registered for this conversion; resolve at the " + + "loader (e.g. SafeTensorsParametersLoader.withPolicy) or change the " + + "policy to Prefer/OneOf to permit fallback." + ) + } + } + } + + is DTypePolicy.Prefer -> { + val targetName = policy.target.name + for ((i, dtypeStr) in inputDtypes.withIndex()) { + if (!dtypeStringMatches(dtypeStr, policy.target)) { + diagnostics += "Node '${node.id}' (${node.operationName}) prefers " + + "$targetName but input $i has dtype '$dtypeStr' — using the existing dtype." + } + } + } + + is DTypePolicy.OneOf -> { + val allowedNames = policy.allowed.joinToString { it.name } + for ((i, dtypeStr) in inputDtypes.withIndex()) { + if (policy.allowed.none { dtypeStringMatches(dtypeStr, it) }) { + throw DtypeConstraintViolationException( + "Node '${node.id}' (${node.operationName}) declares " + + "DTypePolicy.OneOf($allowedNames) but input $i has dtype " + + "'$dtypeStr' which is outside the allowed set. Cast kernels " + + "are not registered; resolve at the loader." + ) + } + } + } + } + + // Mark the node as resolved by this pass. Use copy to keep + // the immutable-copy convention the other passes follow. 
+ val resolved = node.copy(metadata = node.metadata + (RESOLVED_KEY to true)) + graph.removeNode(node) + graph.addNode(resolved) + changed = true + } + + return GraphOptimizationResult(graph, changed = changed, diagnostics = diagnostics) + } + + /** + * Matches the string form used by [sk.ainet.lang.tensor.ops.TensorSpec.dtype] + * against a typed [DType]. Handles both registry-canonical names + * (`"Float32"`, `"BFloat16"`) and the short class-derived + * aliases produced by the DAG DSL's `dtypeName()` helper (`"FP32"`, + * `"BF16"`, `"Int8"`, …). + */ + internal fun dtypeStringMatches(dtypeStr: String, dtype: DType): Boolean { + if (dtypeStr == dtype.name) return true + return when (dtype) { + FP32 -> dtypeStr == "FP32" || dtypeStr == "F32" + FP16 -> dtypeStr == "FP16" || dtypeStr == "F16" + BF16 -> dtypeStr == "BF16" + FP64 -> dtypeStr == "FP64" || dtypeStr == "F64" + Int8 -> dtypeStr == "Int8" || dtypeStr == "I8" + Int16 -> dtypeStr == "Int16" || dtypeStr == "I16" + Int32 -> dtypeStr == "Int32" || dtypeStr == "I32" + Int64 -> dtypeStr == "Int64" || dtypeStr == "I64" + else -> false + } + } + + public companion object { + /** Attribute key shared with the DSL extension (W6). */ + public const val POLICY_KEY: String = "dtype_policy" + + /** Marker the pass writes onto every node it visits. */ + public const val RESOLVED_KEY: String = "dtype_resolved" + } +} + +/** + * Raised when [DTypeConstraintResolutionPass] cannot satisfy a hard + * [DTypePolicy.Require] (or `OneOf` rejection) and no cast kernel + * is available to bridge the gap. Surfaces dtype problems at + * graph-prep time, before forward execution — exactly the RFC's + * "fail before execution" boundary. 
+ */ +public class DtypeConstraintViolationException(message: String) : RuntimeException(message) diff --git a/skainet-compile/skainet-compile-opt/src/commonTest/kotlin/sk/ainet/compile/opt/DTypeConstraintResolutionPassTest.kt b/skainet-compile/skainet-compile-opt/src/commonTest/kotlin/sk/ainet/compile/opt/DTypeConstraintResolutionPassTest.kt new file mode 100644 index 00000000..e2b52b3b --- /dev/null +++ b/skainet-compile/skainet-compile-opt/src/commonTest/kotlin/sk/ainet/compile/opt/DTypeConstraintResolutionPassTest.kt @@ -0,0 +1,120 @@ +package sk.ainet.compile.opt + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import sk.ainet.compile.opt.passes.DTypeConstraintResolutionPass +import sk.ainet.compile.opt.passes.DtypeConstraintViolationException +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.ops.GenericOperation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.types.BF16 +import sk.ainet.lang.types.DTypePolicy +import sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.Int8 + +class DTypeConstraintResolutionPassTest { + + private fun node( + id: String, + opName: String = "matmul", + inputDtype: String = "Float32", + policy: DTypePolicy? 
= null, + ): GraphNode { + val meta = if (policy != null) mapOf(DTypeConstraintResolutionPass.POLICY_KEY to policy) else emptyMap() + return GraphNode( + id = id, + operation = GenericOperation(opName), + inputs = listOf(TensorSpec(name = "$id-in", shape = listOf(4, 4), dtype = inputDtype)), + outputs = listOf(TensorSpec(name = "$id-out", shape = listOf(4, 4), dtype = inputDtype)), + metadata = meta, + ) + } + + @Test + fun nodes_without_policy_are_passed_through() { + val g = DefaultComputeGraph() + g.addNode(node("n0")) + g.addNode(node("n1")) + val result = DTypeConstraintResolutionPass().apply(g) + assertFalse(result.changed, "no policy = no work") + // Neither node should be marked resolved (only visited nodes get the marker). + assertEquals(emptyList(), result.graph.nodes.filter { it.metadata.containsKey(DTypeConstraintResolutionPass.RESOLVED_KEY) }) + } + + @Test + fun any_policy_passes_through() { + val g = DefaultComputeGraph() + g.addNode(node("n0", policy = DTypePolicy.Any)) + val result = DTypeConstraintResolutionPass().apply(g) + assertTrue(result.changed, "the resolved-marker write counts as a change") + val n = result.graph.nodes.single() + assertTrue(n.metadata[DTypeConstraintResolutionPass.RESOLVED_KEY] == true) + } + + @Test + fun require_matching_dtype_passes() { + val g = DefaultComputeGraph() + g.addNode(node("n0", inputDtype = "Float32", policy = DTypePolicy.Require(FP32))) + val result = DTypeConstraintResolutionPass().apply(g) + assertTrue(result.changed) + } + + @Test + fun require_mismatched_dtype_fails_fast() { + val g = DefaultComputeGraph() + g.addNode(node("n0", inputDtype = "Float32", policy = DTypePolicy.Require(BF16))) + val ex = assertFailsWith<DtypeConstraintViolationException> { + DTypeConstraintResolutionPass().apply(g) + } + val msg = ex.message ?: "" + assertTrue(msg.contains("BFloat16"), "msg must name the required dtype: $msg") + assertTrue(msg.contains("Float32"), "msg must name the actual input dtype: $msg") + assertTrue(msg.contains("Cast kernels"), "msg
must hint at the resolution path: $msg") + } + + @Test + fun require_matching_dtype_with_short_alias_also_resolves() { + // DAG DSL emits dtype strings like "FP32" / "BF16" via dtypeName(). + // The pass must handle both the registry canonical name and the short alias. + val g = DefaultComputeGraph() + g.addNode(node("n0", inputDtype = "FP32", policy = DTypePolicy.Require(FP32))) + val result = DTypeConstraintResolutionPass().apply(g) + assertTrue(result.changed, "alias 'FP32' must satisfy Require(FP32)") + } + + @Test + fun prefer_mismatched_dtype_emits_diagnostic_no_throw() { + val g = DefaultComputeGraph() + g.addNode(node("n0", inputDtype = "Float32", policy = DTypePolicy.Prefer(BF16))) + val result = DTypeConstraintResolutionPass().apply(g) + assertTrue(result.changed) + assertTrue( + result.diagnostics.any { it.contains("prefers") && it.contains("BFloat16") }, + "diagnostic must mention the preference: ${result.diagnostics}", + ) + } + + @Test + fun oneOf_in_set_passes() { + val g = DefaultComputeGraph() + g.addNode(node("n0", inputDtype = "Float32", policy = DTypePolicy.OneOf(setOf(FP32, BF16)))) + val result = DTypeConstraintResolutionPass().apply(g) + assertTrue(result.changed) + } + + @Test + fun oneOf_outside_set_fails_fast() { + val g = DefaultComputeGraph() + g.addNode(node("n0", inputDtype = "Float32", policy = DTypePolicy.OneOf(setOf(BF16, Int8)))) + val ex = assertFailsWith<DtypeConstraintViolationException> { + DTypeConstraintResolutionPass().apply(g) + } + val msg = ex.message ?: "" + assertTrue(msg.contains("OneOf"), msg) + assertTrue(msg.contains("Float32"), msg) + } +} diff --git a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/GGUFModelReader.kt b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/GGUFModelReader.kt index 30ae3cfa..b442f880 100644 --- a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/GGUFModelReader.kt +++ b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/GGUFModelReader.kt @@ -3,19 +3,50
@@ package sk.ainet.io.gguf import sk.ainet.io.ModelReader import sk.ainet.io.TensorInfo import sk.ainet.lang.tensor.data.TensorData -import sk.ainet.lang.tensor.Shape -class GGUFModelReader : ModelReader { - override val metadata: Map = mutableMapOf() - override val tensors: Map = mutableMapOf() +/** + * **Legacy facade — use [StreamingGgufParametersLoader] for GGUF loading.** + * + * The pull-style `ModelReader.loadTensor(name)` contract returns a + * raw `TensorData<*, *>` without an `ExecutionContext`. The + * production GGUF loader ([StreamingGgufParametersLoader]) is a + * push-style API bound to a context — it iterates every tensor in + * the file, dispatches per source dtype (F32, F16, BF16, Q4_K, + * Q8_0…) into the matching `TensorData` subtype with explicit + * logical shape, and calls back into user code per tensor. + * + * That push-style API is the right shape for GGUF loading: GGUF + * files store all tensor headers contiguously up front, so iterating + * once is more efficient than seeking back into the file per + * `loadTensor(name)` call. New consumers should construct a + * [StreamingGgufParametersLoader] directly. + * + * This class is kept compiling so existing dependants don't break, + * but `loadTensor` is intentionally a fail-fast stub — using the + * legacy facade for actual GGUF loading silently corrupted dtype + * metadata (no policy hook, no shape verification) and is exactly + * the anti-pattern the dtype-policy RFC (#615) calls out. + */ +@Deprecated( + message = "Use sk.ainet.io.gguf.StreamingGgufParametersLoader instead. " + + "The streaming loader preserves source dtypes (Q4_K, Q8_0, etc.) 
as packed " + + "TensorData subtypes and threads through DTypePolicy for fail-fast resolution.", + replaceWith = ReplaceWith("StreamingGgufParametersLoader"), +) +public class GGUFModelReader : ModelReader { + override val metadata: Map = emptyMap() + override val tensors: Map = emptyMap() override suspend fun loadTensor(name: String): TensorData<*, *> { - val info = tensors[name] ?: error("Tensor $name not found") - // Implementation will use MemoryMappedFileChunk or similar to slice the data - TODO("Not yet implemented: streaming tensor loading for GGUF") + error( + "GGUFModelReader is a legacy facade and does not load GGUF tensors. " + + "Use StreamingGgufParametersLoader(sourceProvider).load(ctx, dtype, onTensorLoaded) " + + "to iterate tensors with their source dtypes preserved (Q4_K, Q8_0, F32, etc.). " + + "See contributing/dtype-model.adoc for the GGUF dtype-mapping reference.", + ) } override fun close() { - // Close underlying resources (e.g. mapped file) + // Nothing to close — this facade owns no resources. 
} } diff --git a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoader.kt b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoader.kt index 00be76c2..324c5da7 100644 --- a/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoader.kt +++ b/skainet-io/skainet-io-gguf/src/commonMain/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoader.kt @@ -7,7 +7,10 @@ import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.tensor.data.Q4_KBlockTensorData import sk.ainet.lang.tensor.data.Q8_0BlockTensorData +import sk.ainet.lang.types.BF16 import sk.ainet.lang.types.DType +import sk.ainet.lang.types.DTypePolicy +import sk.ainet.lang.types.FP16 import sk.ainet.lang.types.FP32 import sk.ainet.lang.types.Int32 import kotlin.reflect.KClass @@ -151,6 +154,72 @@ public class StreamingGgufParametersLoader( } } + public companion object { + + /** + * Convenience constructor that takes a [DTypePolicy] and + * validates it against the dtypes the GGUF loader supports + * today. The validator runs eagerly — if the requested + * policy can never be satisfied by this loader (e.g. + * `Require(Int8)` against a GGUF file: this loader doesn't + * cast), an [IllegalArgumentException] is raised before the + * loader is constructed, exactly matching the RFC's + * "fail before execution" rule. + * + * Current per-source behaviour the validator enforces: + * - GGUF `F32` / `I32` / `Q4_K` / `Q8_0` are always + * preserved verbatim — any policy that admits the + * matching dtype passes. + * - GGUF `F16` / `BF16` always dequant to FP32 in this + * loader today (no KEEP_NATIVE GGUF path yet). A policy + * of `Require(BF16)` or `Require(FP16)` therefore fails + * eagerly; use `Any`, `Prefer`, or `OneOf` containing + * `FP32` if you want the adaptive dequant behaviour. 
+ * + * The validator is conservative — it doesn't open the GGUF + * file to check which dtypes are actually present. A + * policy that's satisfiable in principle but happens to + * conflict with the specific file's tensors will surface at + * iteration time via the `null`-return path in [load]. + */ + public fun withPolicy( + sourceProvider: () -> RandomAccessSource, + policy: DTypePolicy, + onProgress: (current: Long, total: Long, message: String?) -> Unit = { _, _, _ -> }, + ): StreamingGgufParametersLoader { + validatePolicy(policy) + return StreamingGgufParametersLoader(sourceProvider, onProgress) + } + + internal fun validatePolicy(policy: DTypePolicy) { + when (policy) { + DTypePolicy.Any -> Unit + is DTypePolicy.Prefer -> Unit + is DTypePolicy.OneOf -> Unit + is DTypePolicy.Require -> when (policy.target) { + FP32 -> Unit + BF16 -> throw IllegalArgumentException( + "StreamingGgufParametersLoader: Require(BF16) is not supported — " + + "GGUF BF16 sources are dequanted to FP32 by this loader today (no KEEP_NATIVE " + + "GGUF path yet). Use Any or Prefer(BF16) to accept the dequant fallback, or " + + "wait for the policy-aware GGUF reader to land.", + ) + FP16 -> throw IllegalArgumentException( + "StreamingGgufParametersLoader: Require(FP16) is not supported — " + + "GGUF F16 sources are dequanted to FP32 by this loader today (no Fp16DenseTensorData " + + "backing yet). Use Any or Prefer(FP16) to accept the dequant fallback.", + ) + else -> throw IllegalArgumentException( + "StreamingGgufParametersLoader: Require(${policy.target.name}) is not satisfiable — " + + "this loader produces FP32 / Int32 / Q4_K / Q8_0 tensors only, and does not cast " + + "between source dtypes. 
Use Any to inherit the source dtype, or open a follow-up " + + "to add a ${policy.target.name} cast path.", + ) + } + } + } + } + private fun halfToFloat(hbits: Int): Float { val sign = (hbits and 0x8000) shl 16 val exp = (hbits and 0x7C00) shr 10 diff --git a/skainet-io/skainet-io-gguf/src/commonTest/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoaderPolicyTest.kt b/skainet-io/skainet-io-gguf/src/commonTest/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoaderPolicyTest.kt new file mode 100644 index 00000000..bf32f5a0 --- /dev/null +++ b/skainet-io/skainet-io-gguf/src/commonTest/kotlin/sk/ainet/io/gguf/StreamingGgufParametersLoaderPolicyTest.kt @@ -0,0 +1,70 @@ +package sk.ainet.io.gguf + +import kotlin.test.Test +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue +import sk.ainet.lang.types.BF16 +import sk.ainet.lang.types.DTypePolicy +import sk.ainet.lang.types.FP16 +import sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.Int8 + +/** + * Unit tests for `StreamingGgufParametersLoader.validatePolicy` — the + * eager check inside the `withPolicy` factory that fails fast when + * a requested [DTypePolicy] can never be satisfied by the loader's + * current capabilities. Mirrors the SafeTensors policy-adapter test + * (W0b) for the GGUF side (W0c of #615). + */ +class StreamingGgufParametersLoaderPolicyTest { + + @Test + fun any_passes_validation() { + // No throw. 
+ StreamingGgufParametersLoader.validatePolicy(DTypePolicy.Any) + } + + @Test + fun require_fp32_passes_validation() { + StreamingGgufParametersLoader.validatePolicy(DTypePolicy.Require(FP32)) + } + + @Test + fun require_bf16_fails_fast_with_clear_message() { + val ex = assertFailsWith<IllegalArgumentException> { + StreamingGgufParametersLoader.validatePolicy(DTypePolicy.Require(BF16)) + } + val msg = ex.message ?: "" + assertTrue(msg.contains("Require(BF16)"), msg) + assertTrue(msg.contains("KEEP_NATIVE"), msg) + } + + @Test + fun require_fp16_fails_fast_with_clear_message() { + val ex = assertFailsWith<IllegalArgumentException> { + StreamingGgufParametersLoader.validatePolicy(DTypePolicy.Require(FP16)) + } + assertTrue(ex.message?.contains("Require(FP16)") == true, ex.message ?: "") + } + + @Test + fun require_unsupported_target_fails_fast() { + val ex = assertFailsWith<IllegalArgumentException> { + StreamingGgufParametersLoader.validatePolicy(DTypePolicy.Require(Int8)) + } + val msg = ex.message ?: "" + assertTrue(msg.contains("Require(Int8)"), msg) + assertTrue(msg.contains("does not cast"), msg) + } + + @Test + fun prefer_and_oneOf_always_pass_validation() { + // Soft policies fall through silently in the loader, so the + // validator must let them all through regardless of target.
+ StreamingGgufParametersLoader.validatePolicy(DTypePolicy.Prefer(BF16)) + StreamingGgufParametersLoader.validatePolicy(DTypePolicy.Prefer(FP16)) + StreamingGgufParametersLoader.validatePolicy(DTypePolicy.Prefer(Int8)) + StreamingGgufParametersLoader.validatePolicy(DTypePolicy.OneOf(setOf(FP32, BF16))) + StreamingGgufParametersLoader.validatePolicy(DTypePolicy.OneOf(setOf(FP16))) + } +} diff --git a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/Bf16LoadPolicy.kt b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/Bf16LoadPolicy.kt index 192bf9ab..016a9669 100644 --- a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/Bf16LoadPolicy.kt +++ b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/Bf16LoadPolicy.kt @@ -1,5 +1,9 @@ package sk.ainet.io.safetensors +import sk.ainet.lang.types.BF16 +import sk.ainet.lang.types.DTypePolicy +import sk.ainet.lang.types.FP32 + /** * Controls how the SafeTensors loader handles `BFLOAT16` (BF16) tensors. * @@ -53,4 +57,23 @@ public enum class Bf16LoadPolicy { * transformer case). */ KEEP_NATIVE, + ; + + /** + * Maps this BF16-specific enum onto the generalised + * [DTypePolicy] sealed type. [DEQUANT_TO_FP32] becomes + * `Require(FP32)` (the loader must hand consumers an FP32 + * tensor); [KEEP_NATIVE] becomes `Require(BF16)` (consumers + * dispatch on the native BF16 dtype). + * + * Bridge for the RFC's policy-driven loader work + * (`rfc.md`, issue #615): existing call sites keep using this + * enum verbatim while new code paths can flow through + * [DTypePolicy] uniformly. The two are equivalent for BF16 — + * this method is the explicit equivalence proof. 
+ */ + public fun toDTypePolicy(): DTypePolicy = when (this) { + DEQUANT_TO_FP32 -> DTypePolicy.Require(FP32) + KEEP_NATIVE -> DTypePolicy.Require(BF16) + } } diff --git a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt index 4fd0b81c..2d54ed93 100644 --- a/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt +++ b/skainet-io/skainet-io-safetensors/src/commonMain/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoader.kt @@ -8,7 +8,10 @@ import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.tensor.data.Bf16DenseTensorData import sk.ainet.lang.tensor.data.TensorData +import sk.ainet.lang.types.BF16 import sk.ainet.lang.types.DType +import sk.ainet.lang.types.DTypePolicy +import sk.ainet.lang.types.FP16 import sk.ainet.lang.types.FP32 import sk.ainet.lang.types.Int32 import sk.ainet.lang.types.Int8 @@ -286,4 +289,64 @@ class SafeTensorsParametersLoader( } } } + + companion object { + + /** + * Constructs a SafeTensorsParametersLoader from a generalised + * [DTypePolicy] instead of the BF16-specific [Bf16LoadPolicy]. + * Bridge for the policy-driven loader path described in the + * dtype-policy RFC (#615). + * + * Policy → behaviour mapping (BF16 source tensors only — + * other dtypes are handled per the per-arm `require` checks + * in [load]): + * - [DTypePolicy.Any]: BF16 dequants to FP32 (the existing + * adaptive default). + * - [DTypePolicy.Require] target = `BF16`: KEEP_NATIVE. + * - [DTypePolicy.Require] target = `FP32`: DEQUANT_TO_FP32. + * - [DTypePolicy.Require] target = `FP16`: throws — F16 + * KEEP_NATIVE is a follow-up (no `Fp16DenseTensorData` + * yet); use `Require(FP32)` if you want F16 dequanted, or + * `Any` to inherit the adaptive default. 
+ * - [DTypePolicy.Require] target = anything else: throws — + * SafeTensors can't fabricate dtypes the file doesn't carry. + * - [DTypePolicy.Prefer] target = `BF16`: KEEP_NATIVE. + * - [DTypePolicy.Prefer] target = anything else: DEQUANT_TO_FP32 + * (the soft path falls through). + * - [DTypePolicy.OneOf] containing `BF16`: KEEP_NATIVE. + * - [DTypePolicy.OneOf] without `BF16`: DEQUANT_TO_FP32. + */ + fun withPolicy( + sourceProvider: () -> RandomAccessSource, + policy: DTypePolicy, + onProgress: (current: Long, total: Long, message: String?) -> Unit = { _, _, _ -> }, + ): SafeTensorsParametersLoader = SafeTensorsParametersLoader( + sourceProvider = sourceProvider, + onProgress = onProgress, + bf16Policy = mapPolicyToBf16(policy), + ) + + internal fun mapPolicyToBf16(policy: DTypePolicy): Bf16LoadPolicy = when (policy) { + DTypePolicy.Any -> Bf16LoadPolicy.DEQUANT_TO_FP32 + is DTypePolicy.Require -> when (policy.target) { + BF16 -> Bf16LoadPolicy.KEEP_NATIVE + FP32 -> Bf16LoadPolicy.DEQUANT_TO_FP32 + FP16 -> throw IllegalArgumentException( + "SafeTensorsParametersLoader: Require(FP16) is not supported — " + + "F16 KEEP_NATIVE has no Fp16DenseTensorData backing yet. 
" + + "Use Require(FP32) to dequant F16 sources, or Any to inherit the adaptive default.", + ) + else -> throw IllegalArgumentException( + "SafeTensorsParametersLoader: Require(${policy.target.name}) is not satisfiable — " + + "the loader produces FP32 / BF16 / Int32 / Int8 tensors depending on source dtype; " + + "it cannot fabricate ${policy.target.name} from arbitrary sources.", + ) + } + is DTypePolicy.Prefer -> if (policy.target == BF16) Bf16LoadPolicy.KEEP_NATIVE + else Bf16LoadPolicy.DEQUANT_TO_FP32 + is DTypePolicy.OneOf -> if (BF16 in policy.allowed) Bf16LoadPolicy.KEEP_NATIVE + else Bf16LoadPolicy.DEQUANT_TO_FP32 + } + } } diff --git a/skainet-io/skainet-io-safetensors/src/commonTest/kotlin/sk/ainet/io/safetensors/Bf16LoadPolicyToDTypePolicyTest.kt b/skainet-io/skainet-io-safetensors/src/commonTest/kotlin/sk/ainet/io/safetensors/Bf16LoadPolicyToDTypePolicyTest.kt new file mode 100644 index 00000000..4782c0c5 --- /dev/null +++ b/skainet-io/skainet-io-safetensors/src/commonTest/kotlin/sk/ainet/io/safetensors/Bf16LoadPolicyToDTypePolicyTest.kt @@ -0,0 +1,41 @@ +package sk.ainet.io.safetensors + +import kotlin.test.Test +import kotlin.test.assertEquals +import sk.ainet.lang.types.BF16 +import sk.ainet.lang.types.DTypePolicy +import sk.ainet.lang.types.FP32 + +/** + * Verifies the [Bf16LoadPolicy.toDTypePolicy] adapter — the bridge + * between the BF16-specific enum (existing prior art) and the + * generalised [DTypePolicy] sealed type (W1 of #615). Confirms + * both arms of the enum land on equivalent `Require` policies so + * downstream code paths can flow through `DTypePolicy` uniformly. 
+ */ +class Bf16LoadPolicyToDTypePolicyTest { + + @Test + fun dequant_to_fp32_maps_to_require_fp32() { + val policy = Bf16LoadPolicy.DEQUANT_TO_FP32.toDTypePolicy() + assertEquals(DTypePolicy.Require(FP32), policy) + } + + @Test + fun keep_native_maps_to_require_bf16() { + val policy = Bf16LoadPolicy.KEEP_NATIVE.toDTypePolicy() + assertEquals(DTypePolicy.Require(BF16), policy) + } + + @Test + fun adapter_covers_every_enum_arm() { + // Defensive: toDTypePolicy's `when` is exhaustive, so adding an + // arm to Bf16LoadPolicy without updating it is a compile error. + // This loop additionally checks at runtime that every arm maps + // to a Require policy rather than a softer one. + for (arm in Bf16LoadPolicy.entries) { + val mapped = arm.toDTypePolicy() + assertEquals(true, mapped is DTypePolicy.Require, "$arm must map to Require, got $mapped") + } + } +} diff --git a/skainet-io/skainet-io-safetensors/src/commonTest/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoaderPolicyTest.kt b/skainet-io/skainet-io-safetensors/src/commonTest/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoaderPolicyTest.kt new file mode 100644 index 00000000..7296508d --- /dev/null +++ b/skainet-io/skainet-io-safetensors/src/commonTest/kotlin/sk/ainet/io/safetensors/SafeTensorsParametersLoaderPolicyTest.kt @@ -0,0 +1,114 @@ +package sk.ainet.io.safetensors + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import sk.ainet.lang.types.BF16 +import sk.ainet.lang.types.DTypePolicy +import sk.ainet.lang.types.FP16 +import sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.Int8 + +/** + * Unit tests for the `DTypePolicy` → `Bf16LoadPolicy` adapter in + * [SafeTensorsParametersLoader.mapPolicyToBf16]. The `withPolicy` + * factory is a thin wrapper over this mapper plus the existing + * constructor; testing the mapper covers the routing logic without + * needing a real SafeTensors fixture.
+ */ +class SafeTensorsParametersLoaderPolicyTest { + + @Test + fun any_maps_to_dequant_to_fp32() { + assertEquals( + Bf16LoadPolicy.DEQUANT_TO_FP32, + SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Any), + ) + } + + @Test + fun require_bf16_maps_to_keep_native() { + assertEquals( + Bf16LoadPolicy.KEEP_NATIVE, + SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Require(BF16)), + ) + } + + @Test + fun require_fp32_maps_to_dequant() { + assertEquals( + Bf16LoadPolicy.DEQUANT_TO_FP32, + SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Require(FP32)), + ) + } + + @Test + fun require_fp16_fails_with_explicit_message() { + val ex = assertFailsWith<IllegalArgumentException> { + SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Require(FP16)) + } + // The error message must point the operator at the alternative — + // RFC says "fail-fast with clear diagnostics," not just throw. + val msg = ex.message ?: "" + assertEquals(true, msg.contains("Require(FP16)"), "msg: $msg") + assertEquals(true, msg.contains("Fp16DenseTensorData"), "msg: $msg") + } + + @Test + fun require_unsupported_target_fails_with_explicit_message() { + val ex = assertFailsWith<IllegalArgumentException> { + SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Require(Int8)) + } + val msg = ex.message ?: "" + assertEquals(true, msg.contains("Require(Int8)"), "msg: $msg") + assertEquals(true, msg.contains("cannot fabricate"), "msg: $msg") + } + + @Test + fun prefer_bf16_maps_to_keep_native() { + assertEquals( + Bf16LoadPolicy.KEEP_NATIVE, + SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Prefer(BF16)), + ) + } + + @Test + fun prefer_fp32_or_anything_else_maps_to_dequant() { + assertEquals( + Bf16LoadPolicy.DEQUANT_TO_FP32, + SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Prefer(FP32)), + ) + assertEquals( + Bf16LoadPolicy.DEQUANT_TO_FP32, + SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.Prefer(FP16)), + "Prefer is soft — unsatisfiable preferences fall through silently, no throw", + ) + } + + @Test + fun
oneOf_with_bf16_maps_to_keep_native() { + assertEquals( + Bf16LoadPolicy.KEEP_NATIVE, + SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.OneOf(setOf(BF16, FP32))), + ) + } + + @Test + fun oneOf_without_bf16_maps_to_dequant() { + assertEquals( + Bf16LoadPolicy.DEQUANT_TO_FP32, + SafeTensorsParametersLoader.mapPolicyToBf16(DTypePolicy.OneOf(setOf(FP32, FP16))), + ) + } + + @Test + fun parity_with_bf16LoadPolicy_toDTypePolicy() { + // Round-trip property: the BF16 enum's adapter should land on a + // policy that the inverse mapper sends back to the original enum. + for (arm in Bf16LoadPolicy.entries) { + val asDTypePolicy = arm.toDTypePolicy() + val back = SafeTensorsParametersLoader.mapPolicyToBf16(asDTypePolicy) + assertEquals(arm, back, "round-trip failed for $arm via $asDTypePolicy") + } + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/types/DTypePolicy.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/types/DTypePolicy.kt new file mode 100644 index 00000000..8728e283 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/types/DTypePolicy.kt @@ -0,0 +1,100 @@ +package sk.ainet.lang.types + +/** + * Execution-side dtype constraint. + * + * A [DTypePolicy] describes what an op, layer, tensor binding, or + * backend *requires* of an input tensor's dtype — it does NOT + * describe what dtype the source file already contains. The loader + * (or the constraint-resolution pass on a compiled graph) is + * responsible for satisfying the policy before forward execution + * begins: by passing the tensor through unchanged, by casting it, + * or by failing fast if the requirement can't be met. + * + * Maps directly onto the RFC's "policy categories" section + * (`rfc.md`, "DType Constraints as Policies"). The four arms + * cover the full spectrum of strictness: + * + * - [Any]: no constraint — keep the source dtype, whatever it is. 
+ * The adaptive default; this is what every existing call site + * gets implicitly today. + * - [Require]: hard requirement — fail fast at load / compile if + * the tensor can't be made available in the required dtype. + * - [Prefer]: soft requirement — use the preferred dtype if it's + * already available or cheap to produce, otherwise warn and fall + * through. + * - [OneOf]: restricted set — accept any dtype from a small list, + * convert from outside the set if a conversion exists. + * + * Prior art in the codebase: `Bf16LoadPolicy` in + * `skainet-io-safetensors` (the `DEQUANT_TO_FP32 | KEEP_NATIVE` + * enum) is exactly this pattern, scoped to one dtype. [DTypePolicy] + * generalises it so the same shape applies to every dtype the + * engine supports. + */ +public sealed interface DTypePolicy { + /** + * Returns `true` if a tensor that currently has dtype [candidate] + * already satisfies this policy without conversion. Resolution + * code uses this as the fast-path check: if it returns `true`, + * no cast is needed; otherwise the resolver decides whether to + * cast, warn, or fail per the policy arm. + */ + public fun isSatisfiedBy(candidate: DType): Boolean + + /** Adaptive: no dtype constraint. */ + public data object Any : DTypePolicy { + override fun isSatisfiedBy(candidate: DType): Boolean = true + } + + /** + * Hard requirement: the tensor MUST be available in [target]. + * If the source dtype doesn't match and no cast kernel is + * registered to bridge `source → target`, the loader / pass + * raises an error before forward execution can start. + */ + public data class Require(val target: DType) : DTypePolicy { + override fun isSatisfiedBy(candidate: DType): Boolean = + candidate == target + } + + /** + * Soft preference: use [target] if already available or cheap + * to produce, otherwise fall through to the source dtype with + * a warning. 
+ */ + public data class Prefer(val target: DType) : DTypePolicy { + override fun isSatisfiedBy(candidate: DType): Boolean = + candidate == target + } + + /** + * Restricted set: any dtype in [allowed] passes verbatim; + * anything outside the set is a candidate for conversion. + */ + public data class OneOf(val allowed: Set<DType>) : DTypePolicy { + init { + require(allowed.isNotEmpty()) { + "DTypePolicy.OneOf requires a non-empty allowed set" + } + } + + override fun isSatisfiedBy(candidate: DType): Boolean = + candidate in allowed + } + + public companion object { + /** Java-friendly factory for [Any]. */ + @kotlin.jvm.JvmStatic public fun any(): DTypePolicy = Any + + /** Java-friendly factory for [Require]. */ + @kotlin.jvm.JvmStatic public fun require(target: DType): DTypePolicy = Require(target) + + /** Java-friendly factory for [Prefer]. */ + @kotlin.jvm.JvmStatic public fun prefer(target: DType): DTypePolicy = Prefer(target) + + /** Java-friendly factory for [OneOf]. */ + @kotlin.jvm.JvmStatic public fun oneOf(vararg allowed: DType): DTypePolicy = + OneOf(allowed.toSet()) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/types/DTypePolicyTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/types/DTypePolicyTest.kt new file mode 100644 index 00000000..82bd907c --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/types/DTypePolicyTest.kt @@ -0,0 +1,73 @@ +package sk.ainet.lang.types + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class DTypePolicyTest { + + @Test + fun any_isSatisfiedBy_every_dtype() { + for ((_, dtype) in DType.getAllTypes()) { + assertTrue( + DTypePolicy.Any.isSatisfiedBy(dtype), + "DTypePolicy.Any must accept every dtype; rejected $dtype", + ) + } + } + + @Test + fun require_isSatisfiedBy_only_target_dtype() { + val policy =
DTypePolicy.Require(FP32) + assertTrue(policy.isSatisfiedBy(FP32), "Require(FP32) must accept FP32") + assertFalse(policy.isSatisfiedBy(BF16), "Require(FP32) must reject BF16") + assertFalse(policy.isSatisfiedBy(Int8), "Require(FP32) must reject Int8") + } + + @Test + fun prefer_isSatisfiedBy_only_target_dtype() { + // Prefer has the same satisfied-by predicate as Require — the + // difference is in resolution behavior, not in pass-through + // detection. + val policy = DTypePolicy.Prefer(BF16) + assertTrue(policy.isSatisfiedBy(BF16)) + assertFalse(policy.isSatisfiedBy(FP32)) + } + + @Test + fun oneOf_isSatisfiedBy_any_member_of_set() { + val policy = DTypePolicy.OneOf(setOf(FP32, BF16, FP16)) + assertTrue(policy.isSatisfiedBy(FP32)) + assertTrue(policy.isSatisfiedBy(BF16)) + assertTrue(policy.isSatisfiedBy(FP16)) + assertFalse(policy.isSatisfiedBy(Int8)) + assertFalse(policy.isSatisfiedBy(FP64)) + } + + @Test + fun oneOf_rejects_empty_set() { + // OneOf's init block uses require(...), which throws IllegalArgumentException. + assertFailsWith<IllegalArgumentException> { + DTypePolicy.OneOf(emptySet()) + } + } + + @Test + fun data_class_equality() { + assertEquals(DTypePolicy.Require(FP32), DTypePolicy.Require(FP32)) + assertEquals(DTypePolicy.Prefer(BF16), DTypePolicy.Prefer(BF16)) + assertEquals( + DTypePolicy.OneOf(setOf(FP32, BF16)), + DTypePolicy.OneOf(setOf(BF16, FP32)), + ) + } + + @Test + fun java_factories_match_kotlin_constructors() { + assertEquals(DTypePolicy.Any, DTypePolicy.any()) + assertEquals(DTypePolicy.Require(FP32), DTypePolicy.require(FP32)) + assertEquals(DTypePolicy.Prefer(BF16), DTypePolicy.prefer(BF16)) + assertEquals(DTypePolicy.OneOf(setOf(FP32, BF16)), DTypePolicy.oneOf(FP32, BF16)) + } +} diff --git a/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/DtypePolicyDsl.kt b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/DtypePolicyDsl.kt new file mode 100644 index 00000000..3e047718 --- /dev/null +++ b/skainet-lang/skainet-lang-dag/src/commonMain/kotlin/sk/ainet/lang/dag/DtypePolicyDsl.kt @@ -0,0 +1,64 @@
+package sk.ainet.lang.dag + +import sk.ainet.lang.tensor.ops.Operation +import sk.ainet.lang.types.DTypePolicy + +/** + * Attribute key under which [DTypePolicy] is stored on a + * [GraphNodeDefinition]. The constraint-resolution pass + * (`DTypeConstraintResolutionPass`, W7 of #615) reads this key from + * each node's [GraphNodeDefinition.attributes] map. + * + * Lives in `skainet-lang-dag` so both the DSL (this file) and the + * compile-side pass (in `skainet-compile-opt`) agree on the + * convention without either side importing the other. + */ +public const val DTYPE_POLICY_ATTRIBUTE_KEY: String = "dtype_policy" + +/** + * DSL extension on [DagBuilder] that records a graph op with an + * attached [DTypePolicy]. Wraps the existing + * [DagBuilder.op] entry point — the policy lands in the node's + * [GraphNodeDefinition.attributes] under [DTYPE_POLICY_ATTRIBUTE_KEY]. + * + * Usage: + * ```kotlin + * val mm = op( + * operation = matmul, + * inputs = listOf(input, weight), + * dtypePolicy = DTypePolicy.Require(BF16), + * ) + * ``` + * + * Equivalent (but lossier — no constant key, no type help) form + * with the base `op(...)` builder: + * ```kotlin + * op(matmul, listOf(input, weight), + * attributes = mapOf(DTYPE_POLICY_ATTRIBUTE_KEY to DTypePolicy.Require(BF16))) + * ``` + * + * The DSL extension is preferred — typed, discoverable, and + * survives renames cleanly via the constant. + */ +@DagDsl +public fun DagBuilder.op( + operation: Operation, + inputs: List>, + dtypePolicy: DTypePolicy, + id: String = "", + extraAttributes: Map = emptyMap(), +): List> = op( + operation = operation, + inputs = inputs, + id = id, + attributes = extraAttributes + (DTYPE_POLICY_ATTRIBUTE_KEY to dtypePolicy), +) + +/** + * Convenience accessor: extracts the [DTypePolicy] previously + * attached to [node]'s defining-graph-node via [op]. Returns `null` + * if no policy was attached or if the stored value isn't a + * [DTypePolicy] (defensive — `attributes` is `Map`). 
+ */ +public fun GraphNodeDefinition.dtypePolicy(): DTypePolicy? = + attributes[DTYPE_POLICY_ATTRIBUTE_KEY] as? DTypePolicy diff --git a/skainet-lang/skainet-lang-dag/src/commonTest/kotlin/sk/ainet/lang/dag/DtypePolicyDslTest.kt b/skainet-lang/skainet-lang-dag/src/commonTest/kotlin/sk/ainet/lang/dag/DtypePolicyDslTest.kt new file mode 100644 index 00000000..fcb03491 --- /dev/null +++ b/skainet-lang/skainet-lang-dag/src/commonTest/kotlin/sk/ainet/lang/dag/DtypePolicyDslTest.kt @@ -0,0 +1,63 @@ +package sk.ainet.lang.dag + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import sk.ainet.lang.tensor.ops.MatmulOperation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.types.BF16 +import sk.ainet.lang.types.DTypePolicy +import sk.ainet.lang.types.FP32 + +class DtypePolicyDslTest { + + @Test + fun op_with_dtypePolicy_records_policy_under_known_attribute_key() { + val program = dag { + val x = input("x", TensorSpec("x", listOf(1, 4), "Float32")) + val w = parameter("w") { shape(4, 4) { ones() } } + op( + operation = MatmulOperation(), + inputs = listOf(x, w), + dtypePolicy = DTypePolicy.Require(BF16), + ) + } + + val mmNode = program.nodes.last { it.operation is MatmulOperation<*, *> } + val policy = mmNode.dtypePolicy() + assertNotNull(policy, "node must carry the DTypePolicy attribute") + assertEquals(DTypePolicy.Require(BF16), policy) + // The attribute lands under the shared constant key so the + // constraint-resolution pass can find it. 
+ assertEquals(policy, mmNode.attributes[DTYPE_POLICY_ATTRIBUTE_KEY]) + } + + @Test + fun nodes_without_dtypePolicy_return_null_from_accessor() { + val program = dag { + val x = input("x", TensorSpec("x", listOf(1, 4), "Float32")) + val w = parameter("w") { shape(4, 4) { ones() } } + op(MatmulOperation(), listOf(x, w)) // no dtypePolicy + } + val mmNode = program.nodes.last { it.operation is MatmulOperation<*, *> } + assertNull(mmNode.dtypePolicy(), "absent policy must return null, not throw") + } + + @Test + fun extraAttributes_preserved_alongside_dtypePolicy() { + val program = dag { + val x = input("x", TensorSpec("x", listOf(1, 4), "Float32")) + val w = parameter("w") { shape(4, 4) { ones() } } + op( + operation = MatmulOperation(), + inputs = listOf(x, w), + dtypePolicy = DTypePolicy.Prefer(BF16), + extraAttributes = mapOf("note" to "attention projection"), + ) + } + val mmNode = program.nodes.last { it.operation is MatmulOperation<*, *> } + assertEquals(DTypePolicy.Prefer(BF16), mmNode.dtypePolicy()) + assertEquals("attention projection", mmNode.attributes["note"]) + } +}
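Reviewer note: the `DTypePolicy` KDoc in this diff says that when `isSatisfiedBy` returns `false`, the resolver decides whether to cast, warn, or fail per policy arm. The W7 pass is not part of this change, so the following is only a minimal sketch of how that decision could look. Every name in it (`SketchDType`, `SketchPolicy`, `Resolution`, `resolve`, `canCast`) is a hypothetical stand-in, not SKaiNET API, and failing on an unconvertible `OneOf` is an assumption that the KDoc does not state.

```kotlin
// Hypothetical stand-ins: the real DType and DTypePolicy live in sk.ainet.lang.types.
enum class SketchDType { FP32, FP16, BF16, INT8 }

sealed interface SketchPolicy {
    data object NoConstraint : SketchPolicy // mirrors DTypePolicy.Any
    data class Require(val target: SketchDType) : SketchPolicy
    data class Prefer(val target: SketchDType) : SketchPolicy
    data class OneOf(val allowed: Set<SketchDType>) : SketchPolicy
}

// Possible outcomes of resolving one tensor against one policy.
sealed interface Resolution {
    data class PassThrough(val dtype: SketchDType) : Resolution
    data class Cast(val from: SketchDType, val to: SketchDType) : Resolution
    data class FallThroughWithWarning(val dtype: SketchDType, val warning: String) : Resolution
    data class Fail(val reason: String) : Resolution
}

// canCast(from, to) stands in for "a cast kernel is registered for from -> to".
fun resolve(
    policy: SketchPolicy,
    source: SketchDType,
    canCast: (SketchDType, SketchDType) -> Boolean,
): Resolution = when (policy) {
    // No constraint: keep the source dtype, whatever it is.
    SketchPolicy.NoConstraint -> Resolution.PassThrough(source)

    // Hard requirement: pass through, cast, or fail before execution.
    is SketchPolicy.Require -> when {
        source == policy.target -> Resolution.PassThrough(source)
        canCast(source, policy.target) -> Resolution.Cast(source, policy.target)
        else -> Resolution.Fail("Require(${policy.target}) cannot be met from $source")
    }

    // Soft preference: same fast path, but never fails.
    is SketchPolicy.Prefer -> when {
        source == policy.target -> Resolution.PassThrough(source)
        canCast(source, policy.target) -> Resolution.Cast(source, policy.target)
        else -> Resolution.FallThroughWithWarning(
            source,
            "Prefer(${policy.target}) not available; keeping $source",
        )
    }

    // Restricted set: members pass verbatim; otherwise try each allowed
    // dtype in iteration order. Failing when none converts is an assumption.
    is SketchPolicy.OneOf -> {
        if (source in policy.allowed) {
            Resolution.PassThrough(source)
        } else {
            val target = policy.allowed.firstOrNull { canCast(source, it) }
            if (target != null) Resolution.Cast(source, target)
            else Resolution.Fail("OneOf(${policy.allowed}) cannot be met from $source")
        }
    }
}
```

In this sketch `Require` and `Prefer` share the same fast path, matching the test comment that their difference lies in resolution behavior rather than in pass-through detection. Only the final branch differs: `Require` fails, `Prefer` falls through with a warning.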