Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import sk.ainet.tape.Execution
public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor(
public val backendName: String = MinervaExportBackend.backendName,
public val compatibilityValidator: MinervaCompatibilityValidator = MinervaCompatibilityValidator(),
public val graphCanonicalizer: MinervaGraphCanonicalizer = MinervaGraphCanonicalizer()
public val graphCanonicalizer: MinervaGraphCanonicalizer = MinervaGraphCanonicalizer(),
public val npzWriter: MinervaNpzModelWriter = MinervaNpzModelWriter()
) {

/**
Expand Down Expand Up @@ -98,17 +99,31 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor(
return loweringFailedResult(options, context, compatibilityReport, exception)
}

val npzModel = try {
npzWriter.write(intermediate, context)
} catch (exception: MinervaNpzSchemaException) {
return npzSchemaFailedResult(
options = options,
context = context,
compatibilityReport = compatibilityReport,
intermediate = intermediate,
exception = exception
)
}

val failure = MinervaExportFailure(
kind = MinervaExportFailureKind.NOT_IMPLEMENTED,
stage = GraphExportStage.WRITING,
stage = GraphExportStage.PACKAGING,
code = "minerva.export.not_implemented",
message = "Minerva export lowered the graph to phase-one IR; compiler invocation, packaging, and verification are implemented in follow-up issues.",
message = "Minerva export lowered the graph and emitted the NPZ compiler input; compiler invocation, packaging, and verification are implemented in follow-up issues.",
details = mapOf(
"nextStep" to "Invoke the Minerva compiler and write the runtime project.",
"issue" to "#693",
"nextStep" to "Invoke libminerva compiler and package generated outputs.",
"issue" to "#694",
"layers" to intermediate.layerCount.toString(),
"input" to intermediate.input.id,
"output" to intermediate.output.id
"output" to intermediate.output.id,
"npzPath" to npzModel.logicalPath,
"npzBytes" to npzModel.bytes.size.toString()
)
)
context.error(
Expand All @@ -122,7 +137,8 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor(
context = context,
failure = failure,
compatibilityReport = compatibilityReport,
intermediate = intermediate
intermediate = intermediate,
npzModel = npzModel
)
}

Expand Down Expand Up @@ -223,12 +239,49 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor(
)
}

private fun npzSchemaFailedResult(
options: MinervaExportOptions,
context: GraphExportContext,
compatibilityReport: MinervaCompatibilityReport,
intermediate: MinervaIntermediate,
exception: MinervaNpzSchemaException
): MinervaExportResult {
val details = mutableMapOf(
"code" to exception.code,
"issue" to "#693"
)
exception.layerId?.let { details["layerId"] = it }
exception.arrayName?.let { details["arrayName"] = it }
details += exception.details
val failure = MinervaExportFailure(
kind = MinervaExportFailureKind.NPZ_SCHEMA_FAILED,
stage = GraphExportStage.WRITING,
code = exception.code,
message = exception.message ?: "Minerva NPZ schema validation failed.",
details = details
)
context.error(
stage = failure.stage,
code = failure.code,
message = failure.message,
details = failure.details
)
return failedResult(
options = options,
context = context,
failure = failure,
compatibilityReport = compatibilityReport,
intermediate = intermediate
)
}

private fun failedResult(
options: MinervaExportOptions,
context: GraphExportContext,
failure: MinervaExportFailure,
compatibilityReport: MinervaCompatibilityReport? = null,
intermediate: MinervaIntermediate? = null
intermediate: MinervaIntermediate? = null,
npzModel: MinervaNpzModel? = null
): MinervaExportResult {
return MinervaExportResult(
options = options,
Expand All @@ -238,7 +291,8 @@ public class MinervaExportFacade @kotlin.jvm.JvmOverloads constructor(
failure = failure,
metadata = context.metadata,
compatibilityReport = compatibilityReport,
intermediate = intermediate
intermediate = intermediate,
npzModel = npzModel
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ public enum class MinervaExportFailureKind {
GRAPH_VALIDATION_FAILED,
COMPATIBILITY_VALIDATION_FAILED,
LOWERING_FAILED,
NPZ_SCHEMA_FAILED,
NOT_IMPLEMENTED
}

Expand Down Expand Up @@ -203,7 +204,8 @@ public data class MinervaExportResult(
public val failure: MinervaExportFailure? = null,
public val metadata: Map<String, String> = emptyMap(),
public val compatibilityReport: MinervaCompatibilityReport? = null,
public val intermediate: MinervaIntermediate? = null
public val intermediate: MinervaIntermediate? = null,
public val npzModel: MinervaNpzModel? = null
) {
init {
require(status != GraphExportStatus.SUCCESS || bundle != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,88 @@ public class MinervaGraphCanonicalizer @kotlin.jvm.JvmOverloads constructor(
dtype = spec.dtype,
role = role,
sourceNodeId = sourceNode.id,
values = tensorValues(spec, shape, sourceNode, context),
metadata = spec.metadata.mapValues { it.value.toString() }
)
}
}

private fun tensorValues(
spec: TensorSpec,
shape: List<Int>,
sourceNode: GraphNode,
context: GraphExportContext
): List<Float>? {
val elementCount = shape.fold(1) { acc, dim -> acc * dim }
val values = when (val rawValues = spec.metadata["values"]) {
null -> symbolicValues(spec, elementCount)
is FloatArray -> rawValues.toList()
is IntArray -> rawValues.map { it.toFloat() }
is List<*> -> rawValues.map { value ->
when (value) {
is Number -> value.toFloat()
else -> fail(
context = context,
code = "minerva.lowering.tensor_values_invalid",
message = "Tensor '${spec.name}' on node '${sourceNode.id}' has non-numeric initializer data.",
node = sourceNode,
details = mapOf("remediation" to "Use numeric FloatArray or IntArray initializer metadata.")
)
}
}
else -> fail(
context = context,
code = "minerva.lowering.tensor_values_invalid",
message = "Tensor '${spec.name}' on node '${sourceNode.id}' has unsupported initializer metadata.",
node = sourceNode,
details = mapOf(
"valuesType" to rawValues::class.simpleName.orEmpty(),
"remediation" to "Use numeric FloatArray or IntArray initializer metadata."
)
)
} ?: return null
if (values.size != elementCount) {
fail(
context = context,
code = "minerva.lowering.tensor_values_shape_mismatch",
message = "Tensor '${spec.name}' on node '${sourceNode.id}' initializer has ${values.size} value(s), expected $elementCount.",
node = sourceNode,
details = mapOf(
"actual" to values.size.toString(),
"expected" to elementCount.toString(),
"remediation" to "Match initializer data length to the tensor shape."
)
)
}
if (values.any { !it.isFinite() }) {
fail(
context = context,
code = "minerva.lowering.tensor_values_non_finite",
message = "Tensor '${spec.name}' on node '${sourceNode.id}' initializer contains non-finite values.",
node = sourceNode,
details = mapOf("remediation" to "Use finite numeric initializer values.")
)
}
return values
}

private fun symbolicValues(spec: TensorSpec, elementCount: Int): List<Float>? {
return when (val init = spec.metadata["init"]?.toString()) {
"zeros" -> List(elementCount) { 0.0f }
"ones" -> List(elementCount) { 1.0f }
null, "unspecified" -> null
else -> {
if (init.startsWith("full(") && init.endsWith(")")) {
val value = spec.metadata["value"] as? Number
?: init.removePrefix("full(").removeSuffix(")").toFloatOrNull()
if (value != null) List(elementCount) { value.toFloat() } else null
} else {
null
}
}
}
}

private fun tensorId(role: MinervaTensorRole, sourceNodeId: String, tensorName: String): String {
val cleanName = tensorName.replace(Regex("[^A-Za-z0-9_]+"), "_").ifBlank { "tensor" }
val cleanNode = sourceNodeId.replace(Regex("[^A-Za-z0-9_]+"), "_").ifBlank { "node" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public data class MinervaTensorRef(
public val dtype: String,
public val role: MinervaTensorRole,
public val sourceNodeId: String? = null,
public val values: List<Float>? = null,
public val metadata: Map<String, String> = emptyMap()
) {
init {
Expand All @@ -45,6 +46,12 @@ public data class MinervaTensorRef(
require(shape.isNotEmpty()) { "tensor shape cannot be empty" }
require(shape.all { it > 0 }) { "tensor shape dimensions must be positive" }
require(dtype.isNotBlank()) { "tensor dtype cannot be blank" }
require(values == null || values.size == elementCount) {
"tensor values must match tensor element count"
}
require(values == null || values.all { it.isFinite() }) {
"tensor values must be finite"
}
}

public val elementCount: Int
Expand Down Expand Up @@ -99,4 +106,3 @@ public data class MinervaIntermediate(

public fun requireLowered(): MinervaIntermediate = this
}

Loading
Loading