diff --git a/skainet-compile/skainet-compile-minerva/api/skainet-compile-minerva.api b/skainet-compile/skainet-compile-minerva/api/skainet-compile-minerva.api index f02dd290..4fd4cd45 100644 --- a/skainet-compile/skainet-compile-minerva/api/skainet-compile-minerva.api +++ b/skainet-compile/skainet-compile-minerva/api/skainet-compile-minerva.api @@ -1,3 +1,12 @@ +public final class sk/ainet/compile/minerva/MinervaActivation : java/lang/Enum { + public static final field RELU Lsk/ainet/compile/minerva/MinervaActivation; + public static final field SIGMOID Lsk/ainet/compile/minerva/MinervaActivation; + public static final field TANH Lsk/ainet/compile/minerva/MinervaActivation; + public static fun getEntries ()Lkotlin/enums/EnumEntries; + public static fun valueOf (Ljava/lang/String;)Lsk/ainet/compile/minerva/MinervaActivation; + public static fun values ()[Lsk/ainet/compile/minerva/MinervaActivation; +} + public final class sk/ainet/compile/minerva/MinervaCompatibilityIssue { public fun (Lsk/ainet/compile/minerva/MinervaCompatibilityIssueKind;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/Map;)V public synthetic fun (Lsk/ainet/compile/minerva/MinervaCompatibilityIssueKind;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V @@ -114,12 +123,14 @@ public final class sk/ainet/compile/minerva/MinervaExportFacade { public fun ()V public fun (Ljava/lang/String;)V public fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaCompatibilityValidator;)V - public synthetic fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaCompatibilityValidator;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaCompatibilityValidator;Lsk/ainet/compile/minerva/MinervaGraphCanonicalizer;)V + public synthetic fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaCompatibilityValidator;Lsk/ainet/compile/minerva/MinervaGraphCanonicalizer;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun exportGraph (Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/minerva/MinervaExportOptions;)Lsk/ainet/compile/minerva/MinervaExportResult; public final fun exportModel (Ljava/lang/Object;Lkotlin/jvm/functions/Function1;Lsk/ainet/compile/minerva/MinervaExportOptions;)Lsk/ainet/compile/minerva/MinervaExportResult; public final fun exportModel (Ljava/lang/Object;Lsk/ainet/compile/minerva/MinervaExportOptions;)Lsk/ainet/compile/minerva/MinervaExportResult; public final fun getBackendName ()Ljava/lang/String; public final fun getCompatibilityValidator ()Lsk/ainet/compile/minerva/MinervaCompatibilityValidator; + public final fun getGraphCanonicalizer ()Lsk/ainet/compile/minerva/MinervaGraphCanonicalizer; } public final class sk/ainet/compile/minerva/MinervaExportFailure { @@ -145,6 +156,7 @@ public final class sk/ainet/compile/minerva/MinervaExportFailure { public final class sk/ainet/compile/minerva/MinervaExportFailureKind : java/lang/Enum { public static final field COMPATIBILITY_VALIDATION_FAILED Lsk/ainet/compile/minerva/MinervaExportFailureKind; public static final field GRAPH_VALIDATION_FAILED Lsk/ainet/compile/minerva/MinervaExportFailureKind; + public static final field LOWERING_FAILED Lsk/ainet/compile/minerva/MinervaExportFailureKind; public static final field NOT_IMPLEMENTED Lsk/ainet/compile/minerva/MinervaExportFailureKind; public static final field RECORDING_FAILED Lsk/ainet/compile/minerva/MinervaExportFailureKind; public static final field UNSUPPORTED_MODEL_TYPE Lsk/ainet/compile/minerva/MinervaExportFailureKind; @@ -191,8 +203,8 @@ public final class sk/ainet/compile/minerva/MinervaExportOptions { } public final class sk/ainet/compile/minerva/MinervaExportResult { - public fun (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;)V - public synthetic fun (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;)V + public synthetic fun (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Lsk/ainet/compile/minerva/MinervaExportOptions; public final fun component2 ()Lsk/ainet/compile/export/GraphExportStatus; public final fun component3 ()Lsk/ainet/compile/minerva/MinervaExportBundle; @@ -201,8 +213,9 @@ public final class sk/ainet/compile/minerva/MinervaExportResult { public final fun component6 ()Lsk/ainet/compile/minerva/MinervaExportFailure; public final fun component7 ()Ljava/util/Map; public final fun component8 ()Lsk/ainet/compile/minerva/MinervaCompatibilityReport; - public final fun copy (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;)Lsk/ainet/compile/minerva/MinervaExportResult; - public static synthetic fun copy$default (Lsk/ainet/compile/minerva/MinervaExportResult;Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;ILjava/lang/Object;)Lsk/ainet/compile/minerva/MinervaExportResult; + public final fun component9 ()Lsk/ainet/compile/minerva/MinervaIntermediate; + public final fun copy (Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;)Lsk/ainet/compile/minerva/MinervaExportResult; + public static synthetic fun copy$default (Lsk/ainet/compile/minerva/MinervaExportResult;Lsk/ainet/compile/minerva/MinervaExportOptions;Lsk/ainet/compile/export/GraphExportStatus;Lsk/ainet/compile/minerva/MinervaExportBundle;Lsk/ainet/compile/export/GraphExportDiagnosticReport;Ljava/util/List;Lsk/ainet/compile/minerva/MinervaExportFailure;Ljava/util/Map;Lsk/ainet/compile/minerva/MinervaCompatibilityReport;Lsk/ainet/compile/minerva/MinervaIntermediate;ILjava/lang/Object;)Lsk/ainet/compile/minerva/MinervaExportResult; public fun equals (Ljava/lang/Object;)Z public final fun getArtifacts ()Ljava/util/List; public final fun getBundle ()Lsk/ainet/compile/minerva/MinervaExportBundle; @@ -210,6 +223,7 @@ public final class sk/ainet/compile/minerva/MinervaExportResult { public final fun getDiagnostics ()Lsk/ainet/compile/export/GraphExportDiagnosticReport; public final fun getFailed ()Z public final fun getFailure ()Lsk/ainet/compile/minerva/MinervaExportFailure; + public final fun getIntermediate ()Lsk/ainet/compile/minerva/MinervaIntermediate; public final fun getMetadata ()Ljava/util/Map; public final fun getOptions ()Lsk/ainet/compile/minerva/MinervaExportOptions; public final fun getStatus ()Lsk/ainet/compile/export/GraphExportStatus; @@ -219,6 +233,105 @@ public final class sk/ainet/compile/minerva/MinervaExportResult { public fun toString ()Ljava/lang/String; } +public final class sk/ainet/compile/minerva/MinervaGraphCanonicalizer : sk/ainet/compile/export/GraphExportConverter { + public fun ()V + public fun (Lsk/ainet/compile/minerva/MinervaLayerPatternRegistry;)V + public fun (Lsk/ainet/compile/minerva/MinervaLayerPatternRegistry;Ljava/lang/String;)V + public synthetic fun (Lsk/ainet/compile/minerva/MinervaLayerPatternRegistry;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun convert (Ljava/lang/Object;Lsk/ainet/compile/export/GraphExportContext;)Ljava/lang/Object; + public fun convert (Lsk/ainet/lang/graph/ComputeGraph;Lsk/ainet/compile/export/GraphExportContext;)Lsk/ainet/compile/minerva/MinervaIntermediate; + public fun getBackendName ()Ljava/lang/String; + public final fun getPatternRegistry ()Lsk/ainet/compile/minerva/MinervaLayerPatternRegistry; +} + +public final class sk/ainet/compile/minerva/MinervaIntermediate { + public fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTarget;Lsk/ainet/compile/minerva/MinervaQuantization;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Ljava/util/List;Ljava/util/List;Ljava/util/Map;)V + public synthetic fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTarget;Lsk/ainet/compile/minerva/MinervaQuantization;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Ljava/util/List;Ljava/util/List;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Ljava/lang/String; + public final fun component2 ()Lsk/ainet/compile/minerva/MinervaTarget; + public final fun component3 ()Lsk/ainet/compile/minerva/MinervaQuantization; + public final fun component4 ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun component5 ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun component6 ()Ljava/util/List; + public final fun component7 ()Ljava/util/List; + public final fun component8 ()Ljava/util/Map; + public final fun copy (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTarget;Lsk/ainet/compile/minerva/MinervaQuantization;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Ljava/util/List;Ljava/util/List;Ljava/util/Map;)Lsk/ainet/compile/minerva/MinervaIntermediate; + public static synthetic fun copy$default (Lsk/ainet/compile/minerva/MinervaIntermediate;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTarget;Lsk/ainet/compile/minerva/MinervaQuantization;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Ljava/util/List;Ljava/util/List;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/compile/minerva/MinervaIntermediate; + public fun equals (Ljava/lang/Object;)Z + public final fun getInput ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun getLayerCount ()I + public final fun getLayers ()Ljava/util/List; + public final fun getMetadata ()Ljava/util/Map; + public final fun getOutput ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun getProjectName ()Ljava/lang/String; + public final fun getQuantization ()Lsk/ainet/compile/minerva/MinervaQuantization; + public final fun getTarget ()Lsk/ainet/compile/minerva/MinervaTarget; + public final fun getTensors ()Ljava/util/List; + public fun hashCode ()I + public final fun requireLowered ()Lsk/ainet/compile/minerva/MinervaIntermediate; + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/compile/minerva/MinervaLayer { + public fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaLayerKind;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaActivation;Ljava/util/List;Ljava/util/Map;)V + public synthetic fun (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaLayerKind;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaActivation;Ljava/util/List;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Ljava/lang/String; + public final fun component2 ()Lsk/ainet/compile/minerva/MinervaLayerKind; + public final fun component3 ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun component4 ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun component5 ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun component6 ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun component7 ()Lsk/ainet/compile/minerva/MinervaActivation; + public final fun component8 ()Ljava/util/List; + public final fun component9 ()Ljava/util/Map; + public final fun copy (Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaLayerKind;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaActivation;Ljava/util/List;Ljava/util/Map;)Lsk/ainet/compile/minerva/MinervaLayer; + public static synthetic fun copy$default (Lsk/ainet/compile/minerva/MinervaLayer;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaLayerKind;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaTensorRef;Lsk/ainet/compile/minerva/MinervaActivation;Ljava/util/List;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/compile/minerva/MinervaLayer; + public fun equals (Ljava/lang/Object;)Z + public final fun getActivation ()Lsk/ainet/compile/minerva/MinervaActivation; + public final fun getBias ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun getHasBias ()Z + public final fun getId ()Ljava/lang/String; + public final fun getInput ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun getKind ()Lsk/ainet/compile/minerva/MinervaLayerKind; + public final fun getMetadata ()Ljava/util/Map; + public final fun getOutput ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public final fun getSourceNodeIds ()Ljava/util/List; + public final fun getWeights ()Lsk/ainet/compile/minerva/MinervaTensorRef; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/compile/minerva/MinervaLayerKind : java/lang/Enum { + public static final field DENSE Lsk/ainet/compile/minerva/MinervaLayerKind; + public static fun getEntries ()Lkotlin/enums/EnumEntries; + public static fun valueOf (Ljava/lang/String;)Lsk/ainet/compile/minerva/MinervaLayerKind; + public static fun values ()[Lsk/ainet/compile/minerva/MinervaLayerKind; +} + +public final class sk/ainet/compile/minerva/MinervaLayerPatternRegistry { + public fun ()V + public fun (Ljava/util/Set;)V + public fun (Ljava/util/Set;Ljava/lang/String;)V + public fun (Ljava/util/Set;Ljava/lang/String;Ljava/util/Map;)V + public synthetic fun (Ljava/util/Set;Ljava/lang/String;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun activationFor (Ljava/lang/String;)Lsk/ainet/compile/minerva/MinervaActivation; + public final fun getActivationOperations ()Ljava/util/Map; + public final fun getBiasOperation ()Ljava/lang/String; + public final fun getLayerOperations ()Ljava/util/Set; + public final fun isBiasOperation (Ljava/lang/String;)Z + public final fun isLayerOperation (Ljava/lang/String;)Z + public final fun layerKindFor (Ljava/lang/String;)Lsk/ainet/compile/minerva/MinervaLayerKind; +} + +public final class sk/ainet/compile/minerva/MinervaLoweringException : java/lang/IllegalArgumentException { + public fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/Map;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getCode ()Ljava/lang/String; + public final fun getDetails ()Ljava/util/Map; + public final fun getNodeId ()Ljava/lang/String; + public final fun getOperationName ()Ljava/lang/String; +} + public final class sk/ainet/compile/minerva/MinervaQuantization : java/lang/Enum { public static final field Q8 Lsk/ainet/compile/minerva/MinervaQuantization; public final fun getCompilerId ()Ljava/lang/String; @@ -238,3 +351,39 @@ public final class sk/ainet/compile/minerva/MinervaTarget : java/lang/Enum { public static fun values ()[Lsk/ainet/compile/minerva/MinervaTarget; } +public final class sk/ainet/compile/minerva/MinervaTensorRef { + public fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/Map;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Ljava/lang/String; + public final fun component2 ()Ljava/lang/String; + public final fun component3 ()Ljava/util/List; + public final fun component4 ()Ljava/lang/String; + public final fun component5 ()Lsk/ainet/compile/minerva/MinervaTensorRole; + public final fun component6 ()Ljava/lang/String; + public final fun component7 ()Ljava/util/Map; + public final fun copy (Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/Map;)Lsk/ainet/compile/minerva/MinervaTensorRef; + public static synthetic fun copy$default (Lsk/ainet/compile/minerva/MinervaTensorRef;Ljava/lang/String;Ljava/lang/String;Ljava/util/List;Ljava/lang/String;Lsk/ainet/compile/minerva/MinervaTensorRole;Ljava/lang/String;Ljava/util/Map;ILjava/lang/Object;)Lsk/ainet/compile/minerva/MinervaTensorRef; + public fun equals (Ljava/lang/Object;)Z + public final fun getDtype ()Ljava/lang/String; + public final fun getElementCount ()I + public final fun getId ()Ljava/lang/String; + public final fun getMetadata ()Ljava/util/Map; + public final fun getName ()Ljava/lang/String; + public final fun getRole ()Lsk/ainet/compile/minerva/MinervaTensorRole; + public final fun getShape ()Ljava/util/List; + public final fun getSourceNodeId ()Ljava/lang/String; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final class sk/ainet/compile/minerva/MinervaTensorRole : java/lang/Enum { + public static final field BIAS Lsk/ainet/compile/minerva/MinervaTensorRole; + public static final field INPUT Lsk/ainet/compile/minerva/MinervaTensorRole; + public static final field INTERMEDIATE Lsk/ainet/compile/minerva/MinervaTensorRole; + public static final field OUTPUT Lsk/ainet/compile/minerva/MinervaTensorRole; + public static final field WEIGHT Lsk/ainet/compile/minerva/MinervaTensorRole; + public static fun getEntries ()Lkotlin/enums/EnumEntries; + public static fun valueOf (Ljava/lang/String;)Lsk/ainet/compile/minerva/MinervaTensorRole; + public static fun values ()[Lsk/ainet/compile/minerva/MinervaTensorRole; +} + diff --git a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportFacade.kt b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportFacade.kt index 04839b2d..6b1ad930 100644 --- a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportFacade.kt +++ b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportFacade.kt @@ -20,7 +20,8 @@ import sk.ainet.tape.Execution */ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor( public val backendName: String = MinervaExportBackend.backendName, - public val compatibilityValidator: MinervaCompatibilityValidator = MinervaCompatibilityValidator() + public val compatibilityValidator: MinervaCompatibilityValidator = MinervaCompatibilityValidator(), + public val graphCanonicalizer: MinervaGraphCanonicalizer = MinervaGraphCanonicalizer() ) { /** @@ -91,14 +92,23 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor( return compatibilityValidationFailedResult(options, context, compatibilityReport) } + val intermediate = try { + graphCanonicalizer.convert(graph, context) + } catch (exception: MinervaLoweringException) { + return loweringFailedResult(options, context, compatibilityReport, exception) + } + val failure = MinervaExportFailure( kind = MinervaExportFailureKind.NOT_IMPLEMENTED, - stage = GraphExportStage.LOWERING, + stage = GraphExportStage.WRITING, code = "minerva.export.not_implemented", - message = "Minerva export passed phase-one compatibility validation; lowering, compiler invocation, packaging, and verification are implemented in follow-up issues.", + message = "Minerva export lowered the graph to phase-one IR; compiler invocation, packaging, and verification are implemented in follow-up issues.", details = mapOf( - "nextStep" to "Implement MinervaGraphCanonicalizer", - "issue" to "#692" + "nextStep" to "Invoke the Minerva compiler and write the runtime project.", + "issue" to "#693", + "layers" to intermediate.layerCount.toString(), + "input" to intermediate.input.id, + "output" to intermediate.output.id ) ) context.error( @@ -107,7 +117,13 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor( message = failure.message, details = failure.details ) - return failedResult(options, context, failure, compatibilityReport) + return failedResult( + options = options, + context = context, + failure = failure, + compatibilityReport = compatibilityReport, + intermediate = intermediate + ) } private fun unsupportedModelResult(model: Any, options: MinervaExportOptions): MinervaExportResult { @@ -179,11 +195,40 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor( return failedResult(options, context, failure, report) } + private fun loweringFailedResult( + options: MinervaExportOptions, + context: GraphExportContext, + compatibilityReport: MinervaCompatibilityReport, + exception: MinervaLoweringException + ): MinervaExportResult { + val details = mutableMapOf( + "code" to exception.code, + "issue" to "#692" + ) + exception.nodeId?.let { details["nodeId"] = it } + exception.operationName?.let { details["operationName"] = it } + details += exception.details + val failure = MinervaExportFailure( + kind = MinervaExportFailureKind.LOWERING_FAILED, + stage = GraphExportStage.LOWERING, + code = exception.code, + message = exception.message ?: "Minerva graph lowering failed.", + details = details + ) + return failedResult( + options = options, + context = context, + failure = failure, + compatibilityReport = compatibilityReport + ) + } + private fun failedResult( options: MinervaExportOptions, context: GraphExportContext, failure: MinervaExportFailure, - compatibilityReport: MinervaCompatibilityReport? = null + compatibilityReport: MinervaCompatibilityReport? = null, + intermediate: MinervaIntermediate? = null ): MinervaExportResult { return MinervaExportResult( options = options, @@ -192,7 +237,8 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor( artifacts = context.artifacts, failure = failure, metadata = context.metadata, - compatibilityReport = compatibilityReport + compatibilityReport = compatibilityReport, + intermediate = intermediate ) } diff --git a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportModels.kt b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportModels.kt index 794fb9d6..b242b821 100644 --- a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportModels.kt +++ b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaExportModels.kt @@ -91,6 +91,7 @@ public enum class MinervaExportFailureKind { RECORDING_FAILED, GRAPH_VALIDATION_FAILED, COMPATIBILITY_VALIDATION_FAILED, + LOWERING_FAILED, NOT_IMPLEMENTED } @@ -201,7 +202,8 @@ public data class MinervaExportResult( public val artifacts: List = emptyList(), public val failure: MinervaExportFailure? = null, public val metadata: Map = emptyMap(), - public val compatibilityReport: MinervaCompatibilityReport? = null + public val compatibilityReport: MinervaCompatibilityReport? = null, + public val intermediate: MinervaIntermediate? = null ) { init { require(status != GraphExportStatus.SUCCESS || bundle != null) { diff --git a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizer.kt b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizer.kt new file mode 100644 index 00000000..3bb9fa83 --- /dev/null +++ b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizer.kt @@ -0,0 +1,386 @@ +package sk.ainet.compile.minerva + +import sk.ainet.compile.export.GraphExportContext +import sk.ainet.compile.export.GraphExportConverter +import sk.ainet.compile.export.GraphExportStage +import sk.ainet.lang.graph.ComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.ops.TensorSpec + +/** + * Registry for graph fragments that can be lowered into phase-one Minerva IR. + */ +public class MinervaLayerPatternRegistry @kotlin.jvm.JvmOverloads constructor( + layerOperations: Set = setOf("matmul", "dense", "linear"), + public val biasOperation: String = "add", + activationOperations: Map = mapOf( + "relu" to MinervaActivation.RELU, + "sigmoid" to MinervaActivation.SIGMOID, + "tanh" to MinervaActivation.TANH + ) +) { + public val layerOperations: Set = layerOperations.map { it.lowercase() }.toSet() + public val activationOperations: Map = + activationOperations.mapKeys { it.key.lowercase() } + + init { + require(this.layerOperations.isNotEmpty()) { "layerOperations cannot be empty" } + require(this.layerOperations.all { it.isNotBlank() }) { "layerOperations cannot contain blanks" } + require(biasOperation.isNotBlank()) { "biasOperation cannot be blank" } + require(this.activationOperations.isNotEmpty()) { "activationOperations cannot be empty" } + require(this.activationOperations.keys.all { it.isNotBlank() }) { + "activationOperations cannot contain blank operation names" + } + } + + public fun isLayerOperation(operationName: String): Boolean { + return operationName.lowercase() in layerOperations + } + + public fun isBiasOperation(operationName: String): Boolean { + return operationName.lowercase() == biasOperation.lowercase() + } + + public fun activationFor(operationName: String): MinervaActivation? { + return activationOperations[operationName.lowercase()] + } + + public fun layerKindFor(operationName: String): MinervaLayerKind? { + return if (isLayerOperation(operationName)) MinervaLayerKind.DENSE else null + } +} + +/** + * Exception raised when a validated graph still cannot be lowered into Minerva IR. + */ +public class MinervaLoweringException( + message: String, + public val code: String, + public val nodeId: String? = null, + public val operationName: String? = null, + public val details: Map = emptyMap() +) : IllegalArgumentException(message) { + init { + require(code.isNotBlank()) { "lowering exception code cannot be blank" } + } +} + +/** + * Lowers a compatible ComputeGraph into a compact Minerva intermediate. + */ +public class MinervaGraphCanonicalizer @kotlin.jvm.JvmOverloads constructor( + public val patternRegistry: MinervaLayerPatternRegistry = MinervaLayerPatternRegistry(), + override val backendName: String = MinervaExportBackend.backendName +) : GraphExportConverter { + + override fun convert(input: ComputeGraph, context: GraphExportContext): MinervaIntermediate { + context.info( + stage = GraphExportStage.LOWERING, + code = "minerva.lowering.started", + message = "Lowering compatible ComputeGraph to Minerva IR.", + details = mapOf("nodes" to input.nodes.size.toString()) + ) + + val topological = try { + input.getTopologicalOrder() + } catch (exception: Exception) { + fail( + context = context, + code = "minerva.lowering.topology_invalid", + message = exception.message ?: "Unable to determine graph topological order.", + details = mapOf("remediation" to "Validate the graph before Minerva lowering.") + ) + } + + val tensors = linkedMapOf() + val loweredNodeIds = mutableSetOf() + val layers = topological.mapNotNull { node -> + val kind = patternRegistry.layerKindFor(node.operationName) ?: return@mapNotNull null + lowerLayer(input, node, kind, context, tensors).also { layer -> + loweredNodeIds += layer.sourceNodeIds + } + } + + if (layers.isEmpty()) { + fail( + context = context, + code = "minerva.lowering.no_layers", + message = "No lowerable Minerva layer patterns were found.", + details = mapOf("remediation" to "Provide at least one matmul, dense, or linear layer.") + ) + } + + val unlowered = topological.firstOrNull { node -> + node.operationName.lowercase() != "input" && node.id !in loweredNodeIds + } + if (unlowered != null) { + fail( + context = context, + code = "minerva.lowering.unlowered_node", + message = "Node '${unlowered.id}' (${unlowered.operationName}) was not part of a Minerva layer pattern.", + node = unlowered, + details = mapOf("remediation" to "Use dense/matmul with optional add bias and activation fragments.") + ) + } + + val projectName = context.targetName ?: "minerva_model" + val intermediate = MinervaIntermediate( + projectName = projectName, + target = targetFromContext(context), + quantization = quantizationFromContext(context), + input = layers.first().input, + output = layers.last().output, + layers = layers, + tensors = tensors.values.toList(), + metadata = context.metadata + mapOf("lowering" to "minerva-phase-one") + ) + + context.info( + stage = GraphExportStage.LOWERING, + code = "minerva.lowering.completed", + message = "Lowered ComputeGraph to Minerva IR.", + details = mapOf( + "projectName" to intermediate.projectName, + "layers" to intermediate.layerCount.toString(), + "tensors" to intermediate.tensors.size.toString(), + "input" to intermediate.input.id, + "output" to intermediate.output.id + ) + ) + return intermediate + } + + private fun lowerLayer( + graph: ComputeGraph, + layerNode: GraphNode, + kind: MinervaLayerKind, + context: GraphExportContext, + tensors: MutableMap + ): MinervaLayer { + val incoming = incomingEdges(graph, layerNode) + if (incoming.size != 2) { + fail( + context = context, + code = "minerva.lowering.layer_arity", + message = "Layer node '${layerNode.id}' (${layerNode.operationName}) expects data and weight inputs.", + node = layerNode, + details = mapOf( + "expected" to "2", + "actual" to incoming.size.toString(), + "remediation" to "Lower dense layers as matmul(data, weight)." + ) + ) + } + + val dataEdge = incoming[0] + val weightEdge = incoming[1] + val sourceNodeIds = mutableListOf(layerNode.id) + var outputProducer = layerNode + var outputSpec = singleOutput(layerNode, context) + var bias: MinervaTensorRef? = null + var activation: MinervaActivation? = null + + val firstConsumer = singleConsumerOrNull(graph, layerNode, context) + if (firstConsumer != null && patternRegistry.isBiasOperation(firstConsumer.operationName)) { + val addIncoming = incomingEdges(graph, firstConsumer) + val layerToAdd = addIncoming.singleOrNull { it.source == layerNode } + ?: fail( + context = context, + code = "minerva.lowering.bias_add_source", + message = "Bias add node '${firstConsumer.id}' does not consume layer '${layerNode.id}'.", + node = firstConsumer, + details = mapOf("remediation" to "Place add directly after the layer output.") + ) + val biasEdge = addIncoming.singleOrNull { it != layerToAdd } + ?: fail( + context = context, + code = "minerva.lowering.bias_missing", + message = "Bias add node '${firstConsumer.id}' does not have a separate bias input.", + node = firstConsumer, + details = mapOf("remediation" to "Provide add(layer, bias) for bias lowering.") + ) + bias = tensorRef( + spec = biasEdge.tensorSpec, + role = MinervaTensorRole.BIAS, + sourceNode = biasEdge.source, + context = context, + tensors = tensors + ) + outputProducer = firstConsumer + outputSpec = singleOutput(firstConsumer, context) + sourceNodeIds += firstConsumer.id + } else if (firstConsumer != null) { + val directActivation = patternRegistry.activationFor(firstConsumer.operationName) + if (directActivation != null) { + activation = directActivation + outputProducer = firstConsumer + outputSpec = singleOutput(firstConsumer, context) + sourceNodeIds += firstConsumer.id + } + } + + val activationConsumer = singleConsumerOrNull(graph, outputProducer, context) + if (activation == null && activationConsumer != null) { + val activationKind = patternRegistry.activationFor(activationConsumer.operationName) + if (activationKind != null) { + val activationIncoming = incomingEdges(graph, activationConsumer) + if (activationIncoming.singleOrNull()?.source != outputProducer) { + fail( + context = context, + code = "minerva.lowering.activation_source", + message = "Activation node '${activationConsumer.id}' is not directly connected to the layer output.", + node = activationConsumer, + details = mapOf("remediation" to "Place activation directly after layer or bias add output.") + ) + } + activation = activationKind + outputProducer = activationConsumer + outputSpec = singleOutput(activationConsumer, context) + sourceNodeIds += activationConsumer.id + } + } + + singleConsumerOrNull(graph, outputProducer, context) + + val inputRole = if (dataEdge.source.operationName.lowercase() == "input") { + MinervaTensorRole.INPUT + } else { + MinervaTensorRole.INTERMEDIATE + } + val outputRole = if (graph.getOutputNodes().any { it == outputProducer }) { + MinervaTensorRole.OUTPUT + } else { + MinervaTensorRole.INTERMEDIATE + } + + return MinervaLayer( + id = layerNode.id, + kind = kind, + input = tensorRef(dataEdge.tensorSpec, inputRole, dataEdge.source, context, tensors), + weights = tensorRef(weightEdge.tensorSpec, MinervaTensorRole.WEIGHT, weightEdge.source, context, tensors), + bias = bias, + output = tensorRef(outputSpec, outputRole, outputProducer, context, tensors), + activation = activation, + sourceNodeIds = sourceNodeIds, + metadata = mapOf( + "operationName" to layerNode.operationName, + "operationType" to layerNode.operationType + ) + ) + } + + private fun singleOutput(node: GraphNode, context: GraphExportContext): TensorSpec { + if (node.outputs.size != 1) { + fail( + context = context, + code = "minerva.lowering.output_arity", + message = "Node '${node.id}' (${node.operationName}) must produce exactly one tensor.", + node = node, + details = mapOf("actual" to node.outputs.size.toString()) + ) + } + return node.outputs.single() + } + + private fun singleConsumerOrNull( + graph: ComputeGraph, + node: GraphNode, + context: GraphExportContext + ): GraphNode? { + val consumers = outgoingEdges(graph, node).map { it.destination } + if (consumers.size > 1) { + fail( + context = context, + code = "minerva.lowering.branching", + message = "Node '${node.id}' fans out to ${consumers.size} consumers during Minerva lowering.", + node = node, + details = mapOf( + "consumerNodeIds" to consumers.joinToString(",") { it.id }, + "remediation" to "Use one sequential MLP chain per Minerva export." + ) + ) + } + return consumers.singleOrNull() + } + + private fun incomingEdges(graph: ComputeGraph, node: GraphNode): List { + return graph.edges + .filter { it.destination == node } + .sortedBy { it.destinationInputIndex } + } + + private fun outgoingEdges(graph: ComputeGraph, node: GraphNode): List { + return graph.edges.filter { it.source == node } + } + + private fun tensorRef( + spec: TensorSpec, + role: MinervaTensorRole, + sourceNode: GraphNode, + context: GraphExportContext, + tensors: MutableMap + ): MinervaTensorRef { + val shape = spec.shape ?: fail( + context = context, + code = "minerva.lowering.dynamic_shape", + message = "Tensor '${spec.name}' on node '${sourceNode.id}' has no static shape.", + node = sourceNode, + details = mapOf("remediation" to "Run Minerva compatibility validation before lowering.") + ) + val id = tensorId(role, sourceNode.id, spec.name) + return tensors.getOrPut(id) { + MinervaTensorRef( + id = id, + name = spec.name, + shape = shape, + dtype = spec.dtype, + role = role, + sourceNodeId = sourceNode.id, + metadata = spec.metadata.mapValues { it.value.toString() } + ) + } + } + + private fun tensorId(role: MinervaTensorRole, sourceNodeId: String, tensorName: String): String { + val cleanName = tensorName.replace(Regex("[^A-Za-z0-9_]+"), "_").ifBlank { "tensor" } + val cleanNode = sourceNodeId.replace(Regex("[^A-Za-z0-9_]+"), "_").ifBlank { "node" } + return "${role.name.lowercase()}_${cleanNode}_$cleanName" + } + + private fun targetFromContext(context: GraphExportContext): MinervaTarget { + val compilerId = context.metadata["target"] + return MinervaTarget.values().firstOrNull { it.compilerId == compilerId } + ?: MinervaTarget.ATMEGA328P + } + + private fun quantizationFromContext(context: GraphExportContext): MinervaQuantization { + val compilerId = context.metadata["quantization"] + return MinervaQuantization.values().firstOrNull { it.compilerId == compilerId } + ?: MinervaQuantization.Q8 + } + + private fun fail( + context: GraphExportContext, + code: String, + message: String, + node: GraphNode? = null, + details: Map = emptyMap() + ): Nothing { + context.error( + stage = GraphExportStage.LOWERING, + code = code, + message = message, + nodeId = node?.id, + operationName = node?.operationName, + details = details + ) + throw MinervaLoweringException( + message = message, + code = code, + nodeId = node?.id, + operationName = node?.operationName, + details = details + ) + } +} diff --git a/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaIntermediateModels.kt b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaIntermediateModels.kt new file mode 100644 index 00000000..5855efe5 --- /dev/null +++ b/skainet-compile/skainet-compile-minerva/src/commonMain/kotlin/sk/ainet/compile/minerva/MinervaIntermediateModels.kt @@ -0,0 +1,102 @@ +package sk.ainet.compile.minerva + +/** + * Canonical phase-one Minerva layer kinds. + */ +public enum class MinervaLayerKind { + DENSE +} + +/** + * Supported Minerva activation functions. + */ +public enum class MinervaActivation { + RELU, + SIGMOID, + TANH +} + +/** + * Role assigned to a tensor in the lowered Minerva IR. + */ +public enum class MinervaTensorRole { + INPUT, + WEIGHT, + BIAS, + INTERMEDIATE, + OUTPUT +} + +/** + * Tensor reference used by the Minerva intermediate representation. + */ +public data class MinervaTensorRef( + public val id: String, + public val name: String, + public val shape: List, + public val dtype: String, + public val role: MinervaTensorRole, + public val sourceNodeId: String? = null, + public val metadata: Map = emptyMap() +) { + init { + require(id.isNotBlank()) { "tensor id cannot be blank" } + require(name.isNotBlank()) { "tensor name cannot be blank" } + require(shape.isNotEmpty()) { "tensor shape cannot be empty" } + require(shape.all { it > 0 }) { "tensor shape dimensions must be positive" } + require(dtype.isNotBlank()) { "tensor dtype cannot be blank" } + } + + public val elementCount: Int + get() = shape.fold(1) { acc, dim -> acc * dim } +} + +/** + * A lowered phase-one Minerva layer pattern. + */ +public data class MinervaLayer( + public val id: String, + public val kind: MinervaLayerKind, + public val input: MinervaTensorRef, + public val weights: MinervaTensorRef, + public val bias: MinervaTensorRef? = null, + public val output: MinervaTensorRef, + public val activation: MinervaActivation? = null, + public val sourceNodeIds: List, + public val metadata: Map = emptyMap() +) { + init { + require(id.isNotBlank()) { "layer id cannot be blank" } + require(sourceNodeIds.isNotEmpty()) { "layer sourceNodeIds cannot be empty" } + require(sourceNodeIds.all { it.isNotBlank() }) { "layer sourceNodeIds cannot contain blanks" } + } + + public val hasBias: Boolean + get() = bias != null +} + +/** + * Backend intermediate produced after Minerva graph canonicalization. + */ +public data class MinervaIntermediate( + public val projectName: String, + public val target: MinervaTarget, + public val quantization: MinervaQuantization, + public val input: MinervaTensorRef, + public val output: MinervaTensorRef, + public val layers: List, + public val tensors: List, + public val metadata: Map = emptyMap() +) { + init { + require(projectName.isNotBlank()) { "projectName cannot be blank" } + require(layers.isNotEmpty()) { "MinervaIntermediate requires at least one layer" } + require(tensors.isNotEmpty()) { "MinervaIntermediate requires tensor references" } + } + + public val layerCount: Int + get() = layers.size + + public fun requireLowered(): MinervaIntermediate = this +} + diff --git a/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaExportFacadeTest.kt b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaExportFacadeTest.kt index 0f04a963..45f40e64 100644 --- a/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaExportFacadeTest.kt +++ b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaExportFacadeTest.kt @@ -17,6 +17,7 @@ class MinervaExportFacadeTest { val options = minervaTestOptions() assertEquals(MinervaExportBackend.backendName, facade.backendName) + assertEquals(MinervaExportBackend.backendName, facade.graphCanonicalizer.backendName) assertEquals(MinervaTarget.ATMEGA328P, options.target) assertEquals(MinervaQuantization.Q8, options.quantization) assertEquals("jvm-sequential-mlp-q8", options.toMetadata()["phaseOneScope"]) @@ -54,8 +55,11 @@ class MinervaExportFacadeTest { assertEquals(GraphExportStatus.FAILED, result.status) assertEquals(MinervaExportFailureKind.NOT_IMPLEMENTED, result.failure?.kind) assertEquals("minerva.export.not_implemented", result.failure?.code) + assertEquals("#693", result.failure?.details?.get("issue")) assertTrue(result.diagnostics.infos.any { it.code == "minerva.graph.validation.passed" }) + assertTrue(result.diagnostics.infos.any { it.code == "minerva.lowering.completed" }) assertTrue(result.compatibilityReport?.compatible == true) + assertEquals(1, result.intermediate?.layerCount) assertTrue(result.metadata["target"] == MinervaTarget.ATMEGA328P.compilerId) assertFailsWith { result.requireSuccess() @@ -69,6 +73,7 @@ class MinervaExportFacadeTest { assertEquals(MinervaExportFailureKind.NOT_IMPLEMENTED, result.failure?.kind) assertTrue(result.compatibilityReport?.compatible == true) + assertEquals(MinervaActivation.RELU, result.intermediate?.layers?.single()?.activation) } @Test @@ -116,4 +121,21 @@ class MinervaExportFacadeTest { ) assertEquals("conv", result.failure?.details?.get("nodeId")) } + + @Test + fun exportGraphCarriesLoweredIntermediateBeforeCompilerStage() { + val result = MinervaExportFacade().exportGraph( + graph = validMinervaMlpGraph(), + options = minervaTestOptions(projectName = "LoweredMlp") + ) + val intermediate = assertNotNull(result.intermediate) + + assertEquals(GraphExportStatus.FAILED, result.status) + assertEquals(MinervaExportFailureKind.NOT_IMPLEMENTED, result.failure?.kind) + assertEquals("LoweredMlp", intermediate.projectName) + assertEquals(MinervaTensorRole.INPUT, intermediate.input.role) + assertEquals(MinervaTensorRole.OUTPUT, intermediate.output.role) + assertEquals("matmul", intermediate.layers.single().id) + assertEquals("1", result.failure?.details?.get("layers")) + } } diff --git a/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizerTest.kt b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizerTest.kt new file mode 100644 index 00000000..2c791946 --- /dev/null +++ b/skainet-compile/skainet-compile-minerva/src/commonTest/kotlin/sk/ainet/compile/minerva/MinervaGraphCanonicalizerTest.kt @@ -0,0 +1,78 @@ +package sk.ainet.compile.minerva + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue +import sk.ainet.compile.export.GraphExportContext +import sk.ainet.compile.export.GraphExportStage + +class MinervaGraphCanonicalizerTest { + + @Test + fun lowersSupportedMlpPatternIntoDenseLayerIr() { + val options = minervaTestOptions(projectName = "TinyMlp") + val context = GraphExportContext( + backendName = MinervaExportBackend.backendName, + targetName = options.projectName, + metadata = options.toMetadata() + ) + + val intermediate = MinervaGraphCanonicalizer().convert(validMinervaMlpGraph(), context) + val layer = intermediate.layers.single() + + assertEquals("TinyMlp", intermediate.projectName) + assertEquals(MinervaTarget.ATMEGA328P, intermediate.target) + assertEquals(MinervaQuantization.Q8, intermediate.quantization) + assertEquals(1, intermediate.layerCount) + assertEquals(MinervaLayerKind.DENSE, layer.kind) + assertEquals(MinervaTensorRole.INPUT, layer.input.role) + assertEquals(MinervaTensorRole.WEIGHT, layer.weights.role) + assertEquals(MinervaTensorRole.BIAS, layer.bias?.role) + assertEquals(MinervaTensorRole.OUTPUT, layer.output.role) + assertEquals(MinervaActivation.RELU, layer.activation) + assertEquals(listOf("matmul", "bias_add", "relu"), layer.sourceNodeIds) + assertEquals(listOf(1, 3), layer.output.shape) + assertTrue(layer.hasBias) + assertTrue(context.diagnostics.any { it.code == "minerva.lowering.started" }) + assertTrue(context.diagnostics.any { it.code == "minerva.lowering.completed" }) + } + + @Test + fun loweredIntermediateCollectsStableTensorRefs() { + val context = GraphExportContext( + backendName = MinervaExportBackend.backendName, + targetName = "TinyMlp", + metadata = minervaTestOptions().toMetadata() + ) + + val intermediate = MinervaGraphCanonicalizer().convert(validMinervaMlpGraph(), context) + + assertEquals(intermediate.input, intermediate.layers.first().input) + assertEquals(intermediate.output, intermediate.layers.last().output) + assertTrue(intermediate.tensors.any { it.id == "input_input_x" }) + assertTrue(intermediate.tensors.any { it.id == "weight_weight_w" }) + assertTrue(intermediate.tensors.any { it.id == "bias_bias_bias" }) + assertTrue(intermediate.tensors.any { it.id == "output_relu_y" }) + assertEquals(4, intermediate.input.elementCount) + assertEquals(3, intermediate.output.elementCount) + } + + @Test + fun unsupportedPatternFailsWithLoweringDiagnostic() { + val context = GraphExportContext(backendName = MinervaExportBackend.backendName) + + val exception = assertFailsWith { + MinervaGraphCanonicalizer().convert(activationBeforeLayerGraph(), context) + } + + assertEquals("minerva.lowering.no_layers", exception.code) + assertTrue( + context.diagnostics.any { + it.stage == GraphExportStage.LOWERING && + it.code == "minerva.lowering.no_layers" + } + ) + } +} +