diff --git a/skainet-compile/skainet-compile-minerva/api/skainet-compile-minerva.api b/skainet-compile/skainet-compile-minerva/api/skainet-compile-minerva.api index 4fd4cd45..2cc6e05d 100644 --- a/skainet-compile/skainet-compile-minerva/api/skainet-compile-minerva.api +++ b/skainet-compile/skainet-compile-minerva/api/skainet-compile-minerva.api @@ -124,13 +124,15 @@ public final class sk/ainet/compile/minerva/MinervaExportFacade { public fun (Ljava/lang/String;)V public fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaCompatibilityValidator;)V public fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaCompatibilityValidator;Lsk/ainet/compile/minerva/MinervaGraphCanonicalizer;)V - public synthetic fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaCompatibilityValidator;Lsk/ainet/compile/minerva/MinervaGraphCanonicalizer;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaCompatibilityValidator;Lsk/ainet/compile/minerva/MinervaGraphCanonicalizer;Lsk/ainet/compile/minerva/MinervaNpzModelWriter;)V + public synthetic fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaCompatibilityValidator;Lsk/ainet/compile/minerva/MinervaGraphCanonicalizer;Lsk/ainet/compile/minerva/MinervaNpzModelWriter;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun exportGraph (Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/minerva/MinervaExportOptions;)Lsk/ainet/compile/minerva/MinervaExportResult; public final fun exportModel (Ljava/lang/Object;Lkotlin/jvm/functions/Function1;Lsk/ainet/compile/minerva/MinervaExportOptions;)Lsk/ainet/compile/minerva/MinervaExportResult; public final fun exportModel (Ljava/lang/Object;Lsk/ainet/compile/minerva/MinervaExportOptions;)Lsk/ainet/compile/minerva/MinervaExportResult; public final fun getBackendName ()Ljava/lang/String; public final fun getCompatibilityValidator ()Lsk/ainet/compile/minerva/MinervaCompatibilityValidator; public final fun getGraphCanonicalizer ()Lsk/ainet/compile/minerva/MinervaGraphCanonicalizer; + public final fun getNpzWriter ()Lsk/ainet/compile/minerva/MinervaNpzModelWriter; } public final class sk/ainet/compile/minerva/MinervaExportFailure { @@ -158,6 +160,7 @@ public final class sk/ainet/compile/minerva/MinervaExportFailureKind : java/lang public static final field GRAPH_VALIDATION_FAILED Lsk/ainet/compile/minerva/MinervaExportFailureKind; public static final field LOWERING_FAILED Lsk/ainet/compile/minerva/MinervaExportFailureKind; public static final field NOT_IMPLEMENTED Lsk/ainet/compile/minerva/MinervaExportFailureKind; + public static final field NPZ_SCHEMA_FAILED Lsk/ainet/compile/minerva/MinervaExportFailureKind; public static final field RECORDING_FAILED Lsk/ainet/compile/minerva/MinervaExportFailureKind; public static final field UNSUPPORTED_MODEL_TYPE Lsk/ainet/compile/minerva/MinervaExportFailureKind; public static fun getEntries ()Lkotlin/enums/EnumEntries; @@ -203,9 +206,10 @@ public final class sk/ainet/compile/minerva/MinervaExportOptions { } public final class sk/ainet/compile/minerva/MinervaExportResult { - public fun (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;)V - public synthetic fun (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;Lsk/ainet/compile/minerva/MinervaNpzModel;)V + public synthetic fun (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;Lsk/ainet/compile/minerva/MinervaNpzModel;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Lsk/ainet/compile/minerva/MinervaExportOptions; + public final fun component10 ()Lsk/ainet/compile/minerva/MinervaNpzModel; public final fun component2 ()Lsk/ainet/compile/export/GraphExportStatus; public final fun component3 ()Lsk/ainet/compile/minerva/MinervaExportBundle; public final fun component4 ()Lsk/ainet/compile/export/GraphExportDiagnosticReport; @@ -214,8 +218,8 @@ public final class sk/ainet/compile/minerva/MinervaExportResult { public final fun component7 ()Ljava/util/Map; public final fun component8 ()Lsk/ainet/compile/minerva/MinervaCompatibilityReport; public final fun component9 ()Lsk/ainet/compile/minerva/MinervaIntermediate; - public final fun copy (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;)Lsk/ainet/compile/minerva/MinervaExportResult; - public static synthetic fun copy$default (Lsk/ainet/compile/minerva/MinervaExportResult;Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;ILjava/lang/Object;)Lsk/ainet/compile/minerva/MinervaExportResult; + public final fun copy (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;Lsk/ainet/compile/minerva/MinervaNpzModel;)Lsk/ainet/compile/minerva/MinervaExportResult; + public static synthetic fun copy$default (Lsk/ainet/compile/minerva/MinervaExportResult;Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;Lsk/ainet/compile/minerva/MinervaNpzModel;ILjava/lang/Object;)Lsk/ainet/compile/minerva/MinervaExportResult; public fun equals (Ljava/lang/Object;)Z public final fun getArtifacts ()Ljava/util/List; public final fun getBundle ()Lsk/ainet/compile/minerva/MinervaExportBundle; @@ -225,6 +229,7 @@ public final class sk/ainet/compile/minerva/MinervaExportResult { public final fun getFailure ()Lsk/ainet/compile/minerva/MinervaExportFailure; public final fun getIntermediate ()Lsk/ainet/compile/minerva/MinervaIntermediate; public final fun getMetadata ()Ljava/util/Map; + public final fun getNpzModel ()Lsk/ainet/compile/minerva/MinervaNpzModel; public final fun getOptions ()Lsk/ainet/compile/minerva/MinervaExportOptions; public final fun getStatus ()Lsk/ainet/compile/export/GraphExportStatus; public final fun getSucceeded ()Z @@ -332,6 +337,79 @@ public final class sk/ainet/compile/minerva/MinervaLoweringException : java/lang public final fun getOperationName ()Ljava/lang/String; } +public final class sk/ainet/compile/minerva/MinervaNpzArray { + public fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaNpzDType;Ljava/util/List;Ljava/util/List;Ljava/util/List;)V + public synthetic fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaNpzDType;Ljava/util/List;Ljava/util/List;Ljava/util/List;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Ljava/lang/String; + public final fun component2 ()Lsk/ainet/compile/minerva/MinervaNpzDType; + public final fun component3 ()Ljava/util/List; + public final fun component4 ()Ljava/util/List; + public final fun component5 ()Ljava/util/List; + public final fun copy (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaNpzDType;Ljava/util/List;Ljava/util/List;Ljava/util/List;)Lsk/ainet/compile/minerva/MinervaNpzArray; + public static synthetic fun copy$default (Lsk/ainet/compile/minerva/MinervaNpzArray;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaNpzDType;Ljava/util/List;Ljava/util/List;Ljava/util/List;ILjava/lang/Object;)Lsk/ainet/compile/minerva/MinervaNpzArray; + public fun equals (Ljava/lang/Object;)Z + public final fun getDtype ()Lsk/ainet/compile/minerva/MinervaNpzDType; + public final fun getElementCount ()I + public final fun getFloatData ()Ljava/util/List; + public final fun getIntData ()Ljava/util/List; + public final fun getName ()Ljava/lang/String; + public final fun getShape ()Ljava/util/List; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/compile/minerva/MinervaNpzDType : java/lang/Enum { + public static final field FLOAT32 Lsk/ainet/compile/minerva/MinervaNpzDType; + public static final field INT32 Lsk/ainet/compile/minerva/MinervaNpzDType; + public static fun getEntries ()Lkotlin/enums/EnumEntries; + public final fun getNumpyDescriptor ()Ljava/lang/String; + public static fun valueOf (Ljava/lang/String;)Lsk/ainet/compile/minerva/MinervaNpzDType; + public static fun values ()[Lsk/ainet/compile/minerva/MinervaNpzDType; +} + +public final class sk/ainet/compile/minerva/MinervaNpzModel { + public fun (Ljava/lang/String;ILjava/util/List;[BLjava/util/Map;)V + public synthetic fun (Ljava/lang/String;ILjava/util/List;[BLjava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Ljava/lang/String; + public final fun component2 ()I + public final fun component3 ()Ljava/util/List; + public final fun component4 ()[B + public final fun component5 ()Ljava/util/Map; + public final fun copy (Ljava/lang/String;ILjava/util/List;[BLjava/util/Map;)Lsk/ainet/compile/minerva/MinervaNpzModel; + public static synthetic fun copy$default (Lsk/ainet/compile/minerva/MinervaNpzModel;Ljava/lang/String;ILjava/util/List;[BLjava/util/Map;ILjava/lang/Object;)Lsk/ainet/compile/minerva/MinervaNpzModel; + public fun equals (Ljava/lang/Object;)Z + public final fun getArrayNames ()Ljava/util/List; + public final fun getArrays ()Ljava/util/List; + public final fun getBytes ()[B + public final fun getLogicalPath ()Ljava/lang/String; + public final fun getMetadata ()Ljava/util/Map; + public final fun getSchemaVersion ()I + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/compile/minerva/MinervaNpzModelWriter : sk/ainet/compile/export/GraphExportWriter { + public fun ()V + public fun (I)V + public fun (ILjava/lang/String;)V + public fun (ILjava/lang/String;Ljava/lang/String;)V + public synthetic fun (ILjava/lang/String;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun getBackendName ()Ljava/lang/String; + public final fun getLogicalPath ()Ljava/lang/String; + public final fun getSchemaVersion ()I + public synthetic fun write (Ljava/lang/Object;Lsk/ainet/compile/export/GraphExportContext;)Ljava/lang/Object; + public fun write (Lsk/ainet/compile/minerva/MinervaIntermediate;Lsk/ainet/compile/export/GraphExportContext;)Lsk/ainet/compile/minerva/MinervaNpzModel; +} + +public final class sk/ainet/compile/minerva/MinervaNpzSchemaException : java/lang/IllegalArgumentException { + public fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/Map;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getArrayName ()Ljava/lang/String; + public final fun getCode ()Ljava/lang/String; + public final fun getDetails ()Ljava/util/Map; + public final fun getLayerId ()Ljava/lang/String; +} + public final class sk/ainet/compile/minerva/MinervaQuantization : java/lang/Enum { public static final field Q8 Lsk/ainet/compile/minerva/MinervaQuantization; public final fun getCompilerId ()Ljava/lang/String; @@ -352,17 +430,18 @@ public final class sk/ainet/compile/minerva/MinervaTarget : java/lang/Enum { } public final class sk/ainet/compile/minerva/MinervaTensorRef { - public fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/Map;)V - public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/List;Ljava/util/Map;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/List;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Ljava/lang/String; public final fun component2 ()Ljava/lang/String; public final fun component3 ()Ljava/util/List; public final fun component4 ()Ljava/lang/String; public final fun component5 ()Lsk/ainet/compile/minerva/MinervaTensorRole; public final fun component6 ()Ljava/lang/String; - public final fun component7 ()Ljava/util/Map; - public final fun copy (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/Map;)Lsk/ainet/compile/minerva/MinervaTensorRef; - public static synthetic fun copy$default (Lsk/ainet/compile/minerva/MinervaTensorRef;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun component7 ()Ljava/util/List; + public final fun component8 ()Ljava/util/Map; + public final fun copy (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/List;Ljava/util/Map;)Lsk/ainet/compile/minerva/MinervaTensorRef; + public static synthetic fun copy$default (Lsk/ainet/compile/minerva/MinervaTensorRef;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/List;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/compile/minerva/MinervaTensorRef; public fun equals (Ljava/lang/Object;)Z public final fun getDtype ()Ljava/lang/String; public final fun getElementCount ()I @@ -372,6 +451,7 @@ public final class sk/ainet/compile/minerva/MinervaTensorRef { public final fun getRole ()Lsk/ainet/compile/minerva/MinervaTensorRole; public final fun getShape ()Ljava/util/List; public final fun getSourceNodeId ()Ljava/lang/String; + public final fun getValues ()Ljava/util/List; public fun hashCode ()I public fun toString ()Ljava/lang/String; } diff --git a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportFacade.kt b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportFacade.kt index 6b1ad930..2009ccb1 100644 --- a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportFacade.kt +++ b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportFacade.kt @@ -21,7 +21,8 @@ import sk.ainet.tape.Execution public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor( public val backendName: String = MinervaExportBackend.backendName, public val compatibilityValidator: MinervaCompatibilityValidator = MinervaCompatibilityValidator(), - public val graphCanonicalizer: MinervaGraphCanonicalizer = MinervaGraphCanonicalizer() + public val graphCanonicalizer: MinervaGraphCanonicalizer = MinervaGraphCanonicalizer(), + public val npzWriter: MinervaNpzModelWriter = MinervaNpzModelWriter() ) { /** @@ -98,17 +99,31 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor( return loweringFailedResult(options, context, compatibilityReport, exception) } + val npzModel = try { + npzWriter.write(intermediate, context) + } catch (exception: MinervaNpzSchemaException) { + return npzSchemaFailedResult( + options = options, + context = context, + compatibilityReport = compatibilityReport, + intermediate = intermediate, + exception = exception + ) + } + val failure = MinervaExportFailure( kind = MinervaExportFailureKind.NOT_IMPLEMENTED, - stage = GraphExportStage.WRITING, + stage = GraphExportStage.PACKAGING, code = "minerva.export.not_implemented", - message = "Minerva export lowered the graph to phase-one IR; compiler invocation, packaging, and verification are implemented in follow-up issues.", + message = "Minerva export lowered the graph and emitted the NPZ compiler input; compiler invocation, packaging, and verification are implemented in follow-up issues.", details = mapOf( - "nextStep" to "Invoke the Minerva compiler and write the runtime project.", - "issue" to "#693", + "nextStep" to "Invoke libminerva compiler and package generated outputs.", + "issue" to "#694", "layers" to intermediate.layerCount.toString(), "input" to intermediate.input.id, - "output" to intermediate.output.id + "output" to intermediate.output.id, + "npzPath" to npzModel.logicalPath, + "npzBytes" to npzModel.bytes.size.toString() ) ) context.error( @@ -122,7 +137,8 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor( context = context, failure = failure, compatibilityReport = compatibilityReport, - intermediate = intermediate + intermediate = intermediate, + npzModel = npzModel ) } @@ -223,12 +239,49 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor( ) } + private fun npzSchemaFailedResult( + options: MinervaExportOptions, + context: GraphExportContext, + compatibilityReport: MinervaCompatibilityReport, + intermediate: MinervaIntermediate, + exception: MinervaNpzSchemaException + ): MinervaExportResult { + val details = mutableMapOf( + "code" to exception.code, + "issue" to "#693" + ) + exception.layerId?.let { details["layerId"] = it } + exception.arrayName?.let { details["arrayName"] = it } + details += exception.details + val failure = MinervaExportFailure( + kind = MinervaExportFailureKind.NPZ_SCHEMA_FAILED, + stage = GraphExportStage.WRITING, + code = exception.code, + message = exception.message ?: "Minerva NPZ schema validation failed.", + details = details + ) + context.error( + stage = failure.stage, + code = failure.code, + message = failure.message, + details = failure.details + ) + return failedResult( + options = options, + context = context, + failure = failure, + compatibilityReport = compatibilityReport, + intermediate = intermediate + ) + } + private fun failedResult( options: MinervaExportOptions, context: GraphExportContext, failure: MinervaExportFailure, compatibilityReport: MinervaCompatibilityReport? = null, - intermediate: MinervaIntermediate? = null + intermediate: MinervaIntermediate? = null, + npzModel: MinervaNpzModel? = null ): MinervaExportResult { return MinervaExportResult( options = options, @@ -238,7 +291,8 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor( failure = failure, metadata = context.metadata, compatibilityReport = compatibilityReport, - intermediate = intermediate + intermediate = intermediate, + npzModel = npzModel ) } diff --git a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportModels.kt b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportModels.kt index b242b821..8eb723bd 100644 --- a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportModels.kt +++ b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportModels.kt @@ -92,6 +92,7 @@ public enum class MinervaExportFailureKind { GRAPH_VALIDATION_FAILED, COMPATIBILITY_VALIDATION_FAILED, LOWERING_FAILED, + NPZ_SCHEMA_FAILED, NOT_IMPLEMENTED } @@ -203,7 +204,8 @@ public data class MinervaExportResult( public val failure: MinervaExportFailure? = null, public val metadata: Map = emptyMap(), public val compatibilityReport: MinervaCompatibilityReport? = null, - public val intermediate: MinervaIntermediate? = null + public val intermediate: MinervaIntermediate? = null, + public val npzModel: MinervaNpzModel? = null ) { init { require(status != GraphExportStatus.SUCCESS || bundle != null) { diff --git a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizer.kt b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizer.kt index 3bb9fa83..bd24e8a0 100644 --- a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizer.kt +++ b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizer.kt @@ -337,11 +337,88 @@ public class MinervaGraphCanonicalizer @kotlin.jvm.JvmOverloads constructor( dtype = spec.dtype, role = role, sourceNodeId = sourceNode.id, + values = tensorValues(spec, shape, sourceNode, context), metadata = spec.metadata.mapValues { it.value.toString() } ) } } + private fun tensorValues( + spec: TensorSpec, + shape: List, + sourceNode: GraphNode, + context: GraphExportContext + ): List? { + val elementCount = shape.fold(1) { acc, dim -> acc * dim } + val values = when (val rawValues = spec.metadata["values"]) { + null -> symbolicValues(spec, elementCount) + is FloatArray -> rawValues.toList() + is IntArray -> rawValues.map { it.toFloat() } + is List<*> -> rawValues.map { value -> + when (value) { + is Number -> value.toFloat() + else -> fail( + context = context, + code = "minerva.lowering.tensor_values_invalid", + message = "Tensor '${spec.name}' on node '${sourceNode.id}' has non-numeric initializer data.", + node = sourceNode, + details = mapOf("remediation" to "Use numeric FloatArray or IntArray initializer metadata.") + ) + } + } + else -> fail( + context = context, + code = "minerva.lowering.tensor_values_invalid", + message = "Tensor '${spec.name}' on node '${sourceNode.id}' has unsupported initializer metadata.", + node = sourceNode, + details = mapOf( + "valuesType" to rawValues::class.simpleName.orEmpty(), + "remediation" to "Use numeric FloatArray or IntArray initializer metadata." + ) + ) + } ?: return null + if (values.size != elementCount) { + fail( + context = context, + code = "minerva.lowering.tensor_values_shape_mismatch", + message = "Tensor '${spec.name}' on node '${sourceNode.id}' initializer has ${values.size} value(s), expected $elementCount.", + node = sourceNode, + details = mapOf( + "actual" to values.size.toString(), + "expected" to elementCount.toString(), + "remediation" to "Match initializer data length to the tensor shape." + ) + ) + } + if (values.any { !it.isFinite() }) { + fail( + context = context, + code = "minerva.lowering.tensor_values_non_finite", + message = "Tensor '${spec.name}' on node '${sourceNode.id}' initializer contains non-finite values.", + node = sourceNode, + details = mapOf("remediation" to "Use finite numeric initializer values.") + ) + } + return values + } + + private fun symbolicValues(spec: TensorSpec, elementCount: Int): List? { + return when (val init = spec.metadata["init"]?.toString()) { + "zeros" -> List(elementCount) { 0.0f } + "ones" -> List(elementCount) { 1.0f } + null, "unspecified" -> null + else -> { + if (init.startsWith("full(") && init.endsWith(")")) { + val value = spec.metadata["value"] as? Number + ?: init.removePrefix("full(").removeSuffix(")").toFloatOrNull() + if (value != null) List(elementCount) { value.toFloat() } else null + } else { + null + } + } + } + } + private fun tensorId(role: MinervaTensorRole, sourceNodeId: String, tensorName: String): String { val cleanName = tensorName.replace(Regex("[^A-Za-z0-9_]+"), "_").ifBlank { "tensor" } val cleanNode = sourceNodeId.replace(Regex("[^A-Za-z0-9_]+"), "_").ifBlank { "node" } diff --git a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaIntermediateModels.kt b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaIntermediateModels.kt index 5855efe5..0d2e7c67 100644 --- a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaIntermediateModels.kt +++ b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaIntermediateModels.kt @@ -37,6 +37,7 @@ public data class MinervaTensorRef( public val dtype: String, public val role: MinervaTensorRole, public val sourceNodeId: String? = null, + public val values: List? = null, public val metadata: Map = emptyMap() ) { init { @@ -45,6 +46,12 @@ public data class MinervaTensorRef( require(shape.isNotEmpty()) { "tensor shape cannot be empty" } require(shape.all { it > 0 }) { "tensor shape dimensions must be positive" } require(dtype.isNotBlank()) { "tensor dtype cannot be blank" } + require(values == null || values.size == elementCount) { + "tensor values must match tensor element count" + } + require(values == null || values.all { it.isFinite() }) { + "tensor values must be finite" + } } public val elementCount: Int @@ -99,4 +106,3 @@ public data class MinervaIntermediate( public fun requireLowered(): MinervaIntermediate = this } - diff --git a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaNpzModelWriter.kt b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaNpzModelWriter.kt new file mode 100644 index 00000000..bd359f9b --- /dev/null +++ b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaNpzModelWriter.kt @@ -0,0 +1,432 @@ +package sk.ainet.compile.minerva + +import sk.ainet.compile.export.GraphExportArtifact +import sk.ainet.compile.export.GraphExportArtifactRole +import sk.ainet.compile.export.GraphExportContext +import sk.ainet.compile.export.GraphExportStage +import sk.ainet.compile.export.GraphExportWriter + +/** + * Supported NumPy array dtypes emitted by the Minerva NPZ writer. + */ +public enum class MinervaNpzDType(public val numpyDescriptor: String) { + FLOAT32(", + public val floatData: List = emptyList(), + public val intData: List = emptyList() +) { + init { + require(name.isNotBlank()) { "array name cannot be blank" } + require(shape.isNotEmpty()) { "array shape cannot be empty" } + require(shape.all { it >= 0 }) { "array shape dimensions must be non-negative" } + val elementCount = shape.fold(1) { acc, dim -> acc * dim } + when (dtype) { + MinervaNpzDType.FLOAT32 -> { + require(floatData.size == elementCount) { "floatData size must match array element count" } + require(intData.isEmpty()) { "intData must be empty for FLOAT32 arrays" } + require(floatData.all { it.isFinite() }) { "floatData values must be finite" } + } + MinervaNpzDType.INT32 -> { + require(intData.size == elementCount) { "intData size must match array element count" } + require(floatData.isEmpty()) { "floatData must be empty for INT32 arrays" } + } + } + } + + public val elementCount: Int + get() = shape.fold(1) { acc, dim -> acc * dim } +} + +/** + * In-memory Minerva compiler input archive. + */ +public data class MinervaNpzModel( + public val logicalPath: String, + public val schemaVersion: Int, + public val arrays: List, + public val bytes: ByteArray, + public val metadata: Map = emptyMap() +) { + init { + require(logicalPath.isNotBlank()) { "logicalPath cannot be blank" } + require(schemaVersion > 0) { "schemaVersion must be positive" } + require(arrays.isNotEmpty()) { "NPZ model requires arrays" } + require(bytes.isNotEmpty()) { "NPZ model bytes cannot be empty" } + } + + public val arrayNames: List + get() = arrays.map { it.name } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is MinervaNpzModel) return false + return logicalPath == other.logicalPath && + schemaVersion == other.schemaVersion && + arrays == other.arrays && + bytes.contentEquals(other.bytes) && + metadata == other.metadata + } + + override fun hashCode(): Int { + var result = logicalPath.hashCode() + result = 31 * result + schemaVersion + result = 31 * result + arrays.hashCode() + result = 31 * result + bytes.contentHashCode() + result = 31 * result + metadata.hashCode() + return result + } +} + +/** + * Typed schema error for malformed Minerva NPZ compiler input. + */ +public class MinervaNpzSchemaException( + message: String, + public val code: String, + public val layerId: String? = null, + public val arrayName: String? = null, + public val details: Map = emptyMap() +) : IllegalArgumentException(message) { + init { + require(code.isNotBlank()) { "schema exception code cannot be blank" } + } +} + +/** + * Emits the Minerva phase-one NPZ schema from a lowered intermediate model. + */ +public class MinervaNpzModelWriter @kotlin.jvm.JvmOverloads constructor( + public val schemaVersion: Int = 1, + public val logicalPath: String = "model.npz", + override val backendName: String = MinervaExportBackend.backendName +) : GraphExportWriter { + + override fun write(intermediate: MinervaIntermediate, context: GraphExportContext): MinervaNpzModel { + context.info( + stage = GraphExportStage.WRITING, + code = "minerva.npz.started", + message = "Writing Minerva NPZ compiler input.", + details = mapOf( + "schemaVersion" to schemaVersion.toString(), + "layers" to intermediate.layerCount.toString() + ) + ) + + val arrays = arraysFor(intermediate) + val bytes = MinervaNpzArchiveWriter.write(arrays) + val metadata = intermediate.metadata + mapOf( + "schemaVersion" to schemaVersion.toString(), + "layerCount" to intermediate.layerCount.toString(), + "inputShape" to intermediate.input.shape.joinToString("x"), + "outputShape" to intermediate.output.shape.joinToString("x"), + "format" to "npz" + ) + val model = MinervaNpzModel( + logicalPath = logicalPath, + schemaVersion = schemaVersion, + arrays = arrays, + bytes = bytes, + metadata = metadata + ) + + context.addArtifact( + GraphExportArtifact( + path = logicalPath, + role = GraphExportArtifactRole.INTERMEDIATE, + description = "Minerva model NPZ compiler input", + metadata = mapOf( + "schemaVersion" to schemaVersion.toString(), + "layers" to intermediate.layerCount.toString(), + "bytes" to bytes.size.toString() + ) + ) + ) + context.info( + stage = GraphExportStage.WRITING, + code = "minerva.npz.completed", + message = "Wrote Minerva NPZ compiler input.", + details = mapOf( + "path" to logicalPath, + "arrays" to arrays.size.toString(), + "bytes" to bytes.size.toString() + ) + ) + return model + } + + private fun arraysFor(intermediate: MinervaIntermediate): List { + val arrays = mutableListOf( + intArray("schema_version", listOf(1), listOf(schemaVersion)), + intArray("layer_count", listOf(1), listOf(intermediate.layerCount)), + intArray("input_shape", listOf(intermediate.input.shape.size), intermediate.input.shape), + intArray("output_shape", listOf(intermediate.output.shape.size), intermediate.output.shape) + ) + intermediate.layers.forEachIndexed { index, layer -> + arrays += floatArray( + name = "layer_${index}_w", + shape = layer.weights.shape, + values = requiredValues(layer.weights, layer.id, "layer_${index}_w") + ) + arrays += floatArray( + name = "layer_${index}_b", + shape = layer.bias?.shape ?: listOf(0), + values = layer.bias?.let { requiredValues(it, layer.id, "layer_${index}_b") } ?: emptyList() + ) + arrays += intArray( + name = "layer_${index}_act", + shape = listOf(1), + values = listOf(activationCode(layer.activation)) + ) + arrays += intArray( + name = "layer_${index}_input_shape", + shape = listOf(layer.input.shape.size), + values = layer.input.shape + ) + arrays += intArray( + name = "layer_${index}_output_shape", + shape = listOf(layer.output.shape.size), + values = layer.output.shape + ) + } + validateSchema(intermediate, arrays) + return arrays + } + + private fun requiredValues( + tensor: MinervaTensorRef, + layerId: String, + arrayName: String + ): List { + return tensor.values ?: throw MinervaNpzSchemaException( + message = "Tensor '${tensor.id}' has no numeric values for '$arrayName'.", + code = "minerva.npz.missing_values", + layerId = layerId, + arrayName = arrayName, + details = mapOf( + "tensorId" to tensor.id, + "role" to tensor.role.name, + "remediation" to "Attach numeric initializer values to weight and bias TensorSpec metadata before export." + ) + ) + } + + private fun validateSchema(intermediate: MinervaIntermediate, arrays: List) { + val names = arrays.map { it.name } + val duplicates = names.groupingBy { it }.eachCount().filter { it.value > 1 }.keys + if (duplicates.isNotEmpty()) { + throw MinervaNpzSchemaException( + message = "Minerva NPZ schema contains duplicate array names: $duplicates.", + code = "minerva.npz.duplicate_arrays", + details = mapOf("arrayNames" to duplicates.joinToString(",")) + ) + } + intermediate.layers.forEachIndexed { index, layer -> + requireArray(names, "layer_${index}_w", layer.id) + requireArray(names, "layer_${index}_b", layer.id) + requireArray(names, "layer_${index}_act", layer.id) + } + } + + private fun requireArray(names: List, name: String, layerId: String) { + if (name !in names) { + throw MinervaNpzSchemaException( + message = "Minerva NPZ schema is missing required array '$name'.", + code = "minerva.npz.missing_array", + layerId = layerId, + arrayName = name + ) + } + } + + private fun activationCode(activation: MinervaActivation?): Int { + return when (activation) { + null -> 0 + MinervaActivation.RELU -> 1 + MinervaActivation.SIGMOID -> 2 + MinervaActivation.TANH -> 3 + } + } + + private fun floatArray(name: String, shape: List, values: List): MinervaNpzArray { + return MinervaNpzArray(name = name, dtype = MinervaNpzDType.FLOAT32, shape = shape, floatData = values) + } + + private fun intArray(name: String, shape: List, values: List): MinervaNpzArray { + return MinervaNpzArray(name = name, dtype = MinervaNpzDType.INT32, shape = shape, intData = values) + } +} + +private object MinervaNpzArchiveWriter { + fun write(arrays: List): ByteArray { + val entries = arrays.map { array -> + ZipEntryData("${array.name}.npy", NpyWriter.write(array)) + } + return ZipStoreWriter.write(entries) + } +} + +private object NpyWriter { + fun write(array: MinervaNpzArray): ByteArray { + val payload = ByteAccumulator() + when (array.dtype) { + MinervaNpzDType.FLOAT32 -> array.floatData.forEach { payload.writeIntLE(it.toRawBits()) } + MinervaNpzDType.INT32 -> array.intData.forEach { payload.writeIntLE(it) } + } + val header = header(array) + val output = ByteAccumulator() + output.writeByte(0x93) + output.writeAscii("NUMPY") + output.writeByte(1) + output.writeByte(0) + output.writeShortLE(header.size) + output.writeBytes(header) + output.writeBytes(payload.toByteArray()) + return output.toByteArray() + } + + private fun header(array: MinervaNpzArray): ByteArray { + val shapeText = when (array.shape.size) { + 1 -> "(${array.shape.single()},)" + else -> array.shape.joinToString(prefix = "(", postfix = ")") + } + val raw = "{'descr': '${array.dtype.numpyDescriptor}', 'fortran_order': False, 'shape': $shapeText, }" + val preambleSize = 10 + val padding = (16 - ((preambleSize + raw.length + 1) % 16)) % 16 + return (raw + " ".repeat(padding) + "\n").encodeToByteArray() + } +} + +private data class ZipEntryData(val name: String, val data: ByteArray) + +private object ZipStoreWriter { + fun write(entries: List): ByteArray { + val output = ByteAccumulator() + val centralEntries = mutableListOf() + entries.forEach { entry -> + val offset = output.size + val nameBytes = entry.name.encodeToByteArray() + val crc = Crc32.compute(entry.data) + output.writeIntLE(0x04034b50) + output.writeShortLE(20) + output.writeShortLE(0) + output.writeShortLE(0) + output.writeShortLE(0) + output.writeShortLE(0) + output.writeIntLE(crc) + output.writeIntLE(entry.data.size) + output.writeIntLE(entry.data.size) + output.writeShortLE(nameBytes.size) + output.writeShortLE(0) + output.writeBytes(nameBytes) + output.writeBytes(entry.data) + centralEntries += CentralDirectoryEntry(entry.name, crc, entry.data.size, offset) + } + + val centralStart = output.size + centralEntries.forEach { entry -> + val nameBytes = entry.name.encodeToByteArray() + output.writeIntLE(0x02014b50) + output.writeShortLE(20) + output.writeShortLE(20) + output.writeShortLE(0) + output.writeShortLE(0) + output.writeShortLE(0) + output.writeShortLE(0) + output.writeIntLE(entry.crc32) + output.writeIntLE(entry.size) + output.writeIntLE(entry.size) + output.writeShortLE(nameBytes.size) + output.writeShortLE(0) + output.writeShortLE(0) + output.writeShortLE(0) + output.writeShortLE(0) + output.writeIntLE(0) + output.writeIntLE(entry.localHeaderOffset) + output.writeBytes(nameBytes) + } + val centralSize = output.size - centralStart + output.writeIntLE(0x06054b50) + output.writeShortLE(0) + output.writeShortLE(0) + output.writeShortLE(centralEntries.size) + output.writeShortLE(centralEntries.size) + output.writeIntLE(centralSize) + output.writeIntLE(centralStart) + output.writeShortLE(0) + return output.toByteArray() + } +} + +private data class CentralDirectoryEntry( + val name: String, + val crc32: Int, + val size: Int, + val localHeaderOffset: Int +) + +private object Crc32 { + private val table: IntArray = IntArray(256) { index -> + var crc = index + repeat(8) { + crc = if ((crc and 1) != 0) { + (crc ushr 1) xor 0xedb88320.toInt() + } else { + crc ushr 1 + } + } + crc + } + + fun compute(bytes: ByteArray): Int { + var crc = -1 + bytes.forEach { byte -> + crc = (crc ushr 8) xor table[(crc xor byte.toInt()) and 0xff] + } + return crc xor -1 + } +} + +private class ByteAccumulator { + private val bytes = mutableListOf() + + val size: Int + get() = bytes.size + + fun writeByte(value: Int) { + bytes += value.toByte() + } + + fun writeShortLE(value: Int) { + writeByte(value and 0xff) + writeByte((value ushr 8) and 0xff) + } + + fun writeIntLE(value: Int) { + writeByte(value and 0xff) + writeByte((value ushr 8) and 0xff) + writeByte((value ushr 16) and 0xff) + writeByte((value ushr 24) and 0xff) + } + + fun writeAscii(value: String) { + writeBytes(value.encodeToByteArray()) + } + + fun writeBytes(value: ByteArray) { + value.forEach { bytes += it } + } + + fun toByteArray(): ByteArray { + return ByteArray(bytes.size) { index -> bytes[index] } + } +} + diff --git a/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaExportFacadeTest.kt b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaExportFacadeTest.kt index 45f40e64..c040ddb9 100644 --- a/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaExportFacadeTest.kt +++ b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaExportFacadeTest.kt @@ -6,6 +6,7 @@ import kotlin.test.assertFailsWith import kotlin.test.assertFalse import kotlin.test.assertNotNull import kotlin.test.assertTrue +import sk.ainet.compile.export.GraphExportArtifactRole import sk.ainet.compile.export.GraphExportStatus import sk.ainet.lang.graph.DefaultComputeGraph @@ -18,6 +19,7 @@ class MinervaExportFacadeTest { assertEquals(MinervaExportBackend.backendName, facade.backendName) assertEquals(MinervaExportBackend.backendName, facade.graphCanonicalizer.backendName) + assertEquals(MinervaExportBackend.backendName, facade.npzWriter.backendName) assertEquals(MinervaTarget.ATMEGA328P, options.target) assertEquals(MinervaQuantization.Q8, options.quantization) assertEquals("jvm-sequential-mlp-q8", options.toMetadata()["phaseOneScope"]) @@ -55,11 +57,14 @@ class MinervaExportFacadeTest { assertEquals(GraphExportStatus.FAILED, result.status) assertEquals(MinervaExportFailureKind.NOT_IMPLEMENTED, result.failure?.kind) assertEquals("minerva.export.not_implemented", result.failure?.code) - assertEquals("#693", result.failure?.details?.get("issue")) + assertEquals("#694", result.failure?.details?.get("issue")) assertTrue(result.diagnostics.infos.any { it.code == "minerva.graph.validation.passed" }) assertTrue(result.diagnostics.infos.any { it.code == "minerva.lowering.completed" }) + assertTrue(result.diagnostics.infos.any { it.code == "minerva.npz.completed" }) assertTrue(result.compatibilityReport?.compatible == true) assertEquals(1, result.intermediate?.layerCount) + assertTrue(assertNotNull(result.npzModel).bytes.isNotEmpty()) + assertEquals("model.npz", result.artifacts.single { it.role == GraphExportArtifactRole.INTERMEDIATE }.path) assertTrue(result.metadata["target"] == MinervaTarget.ATMEGA328P.compilerId) assertFailsWith { result.requireSuccess() @@ -74,6 +79,7 @@ class MinervaExportFacadeTest { assertEquals(MinervaExportFailureKind.NOT_IMPLEMENTED, result.failure?.kind) assertTrue(result.compatibilityReport?.compatible == true) assertEquals(MinervaActivation.RELU, result.intermediate?.layers?.single()?.activation) + assertEquals(listOf("layer_0_w", "layer_0_b", "layer_0_act"), result.npzModel?.arrayNames?.filter { it.startsWith("layer_0") }?.take(3)) } @Test @@ -137,5 +143,8 @@ class MinervaExportFacadeTest { assertEquals(MinervaTensorRole.OUTPUT, intermediate.output.role) assertEquals("matmul", intermediate.layers.single().id) assertEquals("1", result.failure?.details?.get("layers")) + assertEquals("#694", result.failure?.details?.get("issue")) + assertEquals("model.npz", result.failure?.details?.get("npzPath")) + assertTrue(assertNotNull(result.npzModel).bytes.isNotEmpty()) } } diff --git a/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizerTest.kt b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizerTest.kt index 2c791946..5da945e2 100644 --- a/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizerTest.kt +++ b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizerTest.kt @@ -33,6 +33,8 @@ class MinervaGraphCanonicalizerTest { assertEquals(MinervaActivation.RELU, layer.activation) assertEquals(listOf("matmul", "bias_add", "relu"), layer.sourceNodeIds) assertEquals(listOf(1, 3), layer.output.shape) + assertEquals(layer.weights.elementCount, layer.weights.values?.size) + assertEquals(layer.bias?.elementCount, layer.bias?.values?.size) assertTrue(layer.hasBias) assertTrue(context.diagnostics.any { it.code == "minerva.lowering.started" }) assertTrue(context.diagnostics.any { it.code == "minerva.lowering.completed" }) @@ -75,4 +77,3 @@ class MinervaGraphCanonicalizerTest { ) } } - diff --git a/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaGraphFixtures.kt b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaGraphFixtures.kt index 82e7eb2c..01f6c0d6 100644 --- a/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaGraphFixtures.kt +++ b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaGraphFixtures.kt @@ -8,6 +8,7 @@ import sk.ainet.lang.tensor.ops.GenericOperation import sk.ainet.lang.tensor.ops.InputOperation import sk.ainet.lang.tensor.ops.MatmulOperation import sk.ainet.lang.tensor.ops.ReluOperation +import sk.ainet.lang.tensor.ops.SigmoidOperation import sk.ainet.lang.tensor.ops.TensorSpec import sk.ainet.lang.types.DType @@ -27,9 +28,9 @@ internal fun validMinervaMlpGraph( outputWidth: Int = 3 ): DefaultComputeGraph { val xSpec = spec("x", 1, inputWidth) - val wSpec = spec("w", inputWidth, outputWidth) + val wSpec = spec("w", inputWidth, outputWidth, values = linearValues(inputWidth * outputWidth, start = 0.1f)) val matmulSpec = spec("matmul", 1, outputWidth) - val biasSpec = spec("bias", 1, outputWidth) + val biasSpec = spec("bias", 1, outputWidth, values = linearValues(outputWidth, start = 0.01f)) val addSpec = spec("biased", 1, outputWidth) val ySpec = spec("y", 1, outputWidth) @@ -67,6 +68,78 @@ internal fun validMinervaMlpGraph( ) } +internal fun twoLayerMinervaMlpGraph(): DefaultComputeGraph { + val xSpec = spec("x", 1, 4) + val w0Spec = spec("w0", 4, 3, values = linearValues(12, start = 0.1f)) + val matmul0Spec = spec("matmul0", 1, 3) + val b0Spec = spec("b0", 1, 3, values = linearValues(3, start = 0.01f)) + val add0Spec = spec("add0", 1, 3) + val relu0Spec = spec("relu0", 1, 3) + val w1Spec = spec("w1", 3, 2, values = linearValues(6, start = -0.2f)) + val matmul1Spec = spec("matmul1", 1, 2) + val b1Spec = spec("b1", 1, 2, values = linearValues(2, start = -0.03f)) + val add1Spec = spec("add1", 1, 2) + val ySpec = spec("y", 1, 2) + + val x = inputNode("input", xSpec) + val w0 = inputNode("weight0", w0Spec) + val matmul0 = GraphNode( + id = "matmul0", + operation = MatmulOperation(), + inputs = listOf(xSpec, w0Spec), + outputs = listOf(matmul0Spec) + ) + val b0 = inputNode("bias0", b0Spec) + val add0 = GraphNode( + id = "bias_add0", + operation = AddOperation(), + inputs = listOf(matmul0Spec, b0Spec), + outputs = listOf(add0Spec) + ) + val relu0 = GraphNode( + id = "relu0", + operation = ReluOperation(), + inputs = listOf(add0Spec), + outputs = listOf(relu0Spec) + ) + val w1 = inputNode("weight1", w1Spec) + val matmul1 = GraphNode( + id = "matmul1", + operation = MatmulOperation(), + inputs = listOf(relu0Spec, w1Spec), + outputs = listOf(matmul1Spec) + ) + val b1 = inputNode("bias1", b1Spec) + val add1 = GraphNode( + id = "bias_add1", + operation = AddOperation(), + inputs = listOf(matmul1Spec, b1Spec), + outputs = listOf(add1Spec) + ) + val sigmoid = GraphNode( + id = "sigmoid", + operation = SigmoidOperation(), + inputs = listOf(add1Spec), + outputs = listOf(ySpec) + ) + + return graphOf( + nodes = listOf(x, w0, matmul0, b0, add0, relu0, w1, matmul1, b1, add1, sigmoid), + edges = listOf( + edge("x_to_matmul0", x, matmul0, xSpec, destinationInputIndex = 0), + edge("w0_to_matmul0", w0, matmul0, w0Spec, destinationInputIndex = 1), + edge("matmul0_to_add0", matmul0, add0, matmul0Spec, destinationInputIndex = 0), + edge("b0_to_add0", b0, add0, b0Spec, destinationInputIndex = 1), + edge("add0_to_relu0", add0, relu0, add0Spec), + edge("relu0_to_matmul1", relu0, matmul1, relu0Spec, destinationInputIndex = 0), + edge("w1_to_matmul1", w1, matmul1, w1Spec, destinationInputIndex = 1), + edge("matmul1_to_add1", matmul1, add1, matmul1Spec, destinationInputIndex = 0), + edge("b1_to_add1", b1, add1, b1Spec, destinationInputIndex = 1), + edge("add1_to_sigmoid", add1, sigmoid, add1Spec) + ) + ) +} + internal fun unsupportedMinervaOperationGraph(): DefaultComputeGraph { val inputSpec = spec("x", 1, 4) val outputSpec = spec("conv_out", 1, 4) @@ -159,8 +232,13 @@ private fun inputNode(id: String, output: TensorSpec): GraphNode { ) } -private fun spec(name: String, vararg shape: Int): TensorSpec { - return TensorSpec(name, shape.toList(), "Float32") +private fun spec(name: String, vararg shape: Int, values: List? = null): TensorSpec { + val metadata: Map = values?.let { mapOf("values" to it.toFloatArray()) } ?: emptyMap() + return TensorSpec(name, shape.toList(), "Float32", metadata = metadata) +} + +private fun linearValues(count: Int, start: Float): List { + return List(count) { index -> start + (index * 0.05f) } } private fun graphOf(nodes: List, edges: List): DefaultComputeGraph { diff --git a/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaNpzModelWriterTest.kt b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaNpzModelWriterTest.kt new file mode 100644 index 00000000..a0e03ac1 --- /dev/null +++ b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaNpzModelWriterTest.kt @@ -0,0 +1,158 @@ +package sk.ainet.compile.minerva + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue +import sk.ainet.compile.export.GraphExportArtifactRole +import sk.ainet.compile.export.GraphExportContext + +class MinervaNpzModelWriterTest { + + @Test + fun writesDeterministicNpzSchemaForTwoLayerMlp() { + val context = minervaContext(projectName = "TwoLayerMlp") + val intermediate = MinervaGraphCanonicalizer().convert(twoLayerMinervaMlpGraph(), context) + val writer = MinervaNpzModelWriter() + + val first = writer.write(intermediate, context) + val second = writer.write( + MinervaGraphCanonicalizer().convert(twoLayerMinervaMlpGraph(), minervaContext(projectName = "TwoLayerMlp")), + minervaContext(projectName = "SecondWriterContext") + ) + + assertTrue(first.bytes.contentEquals(second.bytes)) + assertEquals( + listOf( + "schema_version", + "layer_count", + "input_shape", + "output_shape", + "layer_0_w", + "layer_0_b", + "layer_0_act", + "layer_0_input_shape", + "layer_0_output_shape", + "layer_1_w", + "layer_1_b", + "layer_1_act", + "layer_1_input_shape", + "layer_1_output_shape" + ), + first.arrayNames + ) + assertEquals(1, first.schemaVersion) + assertEquals("1", first.metadata["schemaVersion"]) + assertEquals("2", first.metadata["layerCount"]) + assertEquals("1x4", first.metadata["inputShape"]) + assertEquals("1x2", first.metadata["outputShape"]) + assertEquals(listOf(4, 3), first.array("layer_0_w").shape) + assertEquals(12, first.array("layer_0_w").floatData.size) + assertEquals(listOf(1), first.array("layer_0_act").intData) + assertEquals(listOf(2), first.array("layer_1_act").intData) + assertTrue(context.artifacts.any { it.path == "model.npz" && it.role == GraphExportArtifactRole.INTERMEDIATE }) + assertTrue(context.diagnostics.any { it.code == "minerva.npz.completed" }) + } + + @Test + fun generatedArchiveContainsReadableNpyEntries() { + val context = minervaContext(projectName = "ReadableNpz") + val intermediate = MinervaGraphCanonicalizer().convert(twoLayerMinervaMlpGraph(), context) + val model = MinervaNpzModelWriter().write(intermediate, context) + + val entries = readZipStoreEntries(model.bytes) + + assertEquals(model.arrayNames.map { "$it.npy" }, entries.map { it.name }) + assertTrue(entries.all { it.data.startsWithNpyMagic() }) + assertTrue(entries.single { it.name == "layer_0_w.npy" }.npyHeader().contains("'descr': ' + if (tensor.id == layer.weights.id) tensor.copy(values = null) else tensor + } + ) + + val exception = assertFailsWith { + MinervaNpzModelWriter().write(broken, context) + } + + assertEquals("minerva.npz.missing_values", exception.code) + assertEquals(layer.id, exception.layerId) + assertEquals("layer_0_w", exception.arrayName) + assertEquals(layer.weights.id, exception.details["tensorId"]) + } + + private fun minervaContext(projectName: String): GraphExportContext { + val options = minervaTestOptions(projectName = projectName) + return GraphExportContext( + backendName = MinervaExportBackend.backendName, + targetName = options.projectName, + metadata = options.toMetadata() + ) + } + + private fun MinervaNpzModel.array(name: String): MinervaNpzArray { + return arrays.single { it.name == name } + } + + private fun ByteArray.startsWithNpyMagic(): Boolean { + return size >= 6 && + this[0] == 0x93.toByte() && + this[1] == 'N'.code.toByte() && + this[2] == 'U'.code.toByte() && + this[3] == 'M'.code.toByte() && + this[4] == 'P'.code.toByte() && + this[5] == 'Y'.code.toByte() + } + + private fun ZipStoreEntry.npyHeader(): String { + val headerLength = data.readShortLE(offset = 8) + return data.copyOfRange(10, 10 + headerLength).decodeToString() + } + + private fun readZipStoreEntries(bytes: ByteArray): List { + val entries = mutableListOf() + var offset = 0 + while (offset + LOCAL_HEADER_SIZE <= bytes.size && bytes.readIntLE(offset) == LOCAL_FILE_HEADER_SIGNATURE) { + val compressedSize = bytes.readIntLE(offset + 18) + val nameLength = bytes.readShortLE(offset + 26) + val extraLength = bytes.readShortLE(offset + 28) + val nameStart = offset + LOCAL_HEADER_SIZE + val dataStart = nameStart + nameLength + extraLength + val dataEnd = dataStart + compressedSize + val name = bytes.copyOfRange(nameStart, nameStart + nameLength).decodeToString() + entries += ZipStoreEntry(name = name, data = bytes.copyOfRange(dataStart, dataEnd)) + offset = dataEnd + } + assertTrue(entries.isNotEmpty()) + assertEquals(CENTRAL_DIRECTORY_SIGNATURE, bytes.readIntLE(offset)) + return entries + } + + private fun ByteArray.readShortLE(offset: Int): Int { + return (this[offset].toInt() and 0xff) or + ((this[offset + 1].toInt() and 0xff) shl 8) + } + + private fun ByteArray.readIntLE(offset: Int): Int { + return readShortLE(offset) or (readShortLE(offset + 2) shl 16) + } + + private data class ZipStoreEntry(val name: String, val data: ByteArray) + + private companion object { + const val LOCAL_FILE_HEADER_SIGNATURE: Int = 0x04034b50 + const val CENTRAL_DIRECTORY_SIGNATURE: Int = 0x02014b50 + const val LOCAL_HEADER_SIZE: Int = 30 + } +}