diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt index 96d52d15..da71bed0 100644 --- a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt @@ -12,6 +12,10 @@ import sk.ainet.lang.tensor.data.FloatArrayTensorData import sk.ainet.lang.tensor.data.TensorDataFactory import sk.ainet.lang.tensor.ops.UpsampleMode import sk.ainet.lang.types.FP32 +import kotlin.math.ln +import kotlin.math.log10 as kmLog10 +import kotlin.math.log2 as kmLog2 +import kotlin.math.pow import kotlin.math.sqrt @Backend(id = "cpu", displayName = "CPU") @@ -2123,6 +2127,112 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory return newTensor(outData, tensor.dtype, tensor) } + /** + * Element-wise power: `c[i] = a[i] ^ b[i]`. Integer-valued exponents + * use repeated multiply for stability; everything else routes through + * `kotlin.math.pow`. Shape contract: shapes must match exactly (no + * broadcasting yet — caller's responsibility). + */ + override fun pow(a: Tensor, b: Tensor): Tensor { + require( + a.dtype == sk.ainet.lang.types.FP32::class || + a.dtype == sk.ainet.lang.types.FP16::class + ) { "pow supports only FP16/FP32, got ${a.dtype}" } + require(a.shape == b.shape) { "pow requires matching shapes; got ${a.shape} and ${b.shape}" } + val outData = dataFactory.init(a.shape, a.dtype) { idx -> + val av = a.data.get(*idx) as Float + val bv = b.data.get(*idx) as Float + @Suppress("UNCHECKED_CAST") + scalarPow(av, bv) as V + } + return newTensor(outData, a.dtype, a) + } + + /** + * Element-wise scalar power: `c[i] = a[i] ^ n`. Small-integer + * exponents (|n| <= 16) use repeated multiply for exactness; all + * other values route through `kotlin.math.pow`. + */ + override fun powScalar(a: Tensor, n: Number): Tensor { + require( + a.dtype == sk.ainet.lang.types.FP32::class || + a.dtype == sk.ainet.lang.types.FP16::class + ) { "powScalar supports only FP16/FP32, got ${a.dtype}" } + val nFloat = n.toFloat() + val nInt = n.toInt() + val isSmallInt = nFloat == nInt.toFloat() && kotlin.math.abs(nInt) <= 16 + val outData = dataFactory.init(a.shape, a.dtype) { idx -> + val av = a.data.get(*idx) as Float + @Suppress("UNCHECKED_CAST") + (if (isSmallInt) integerPow(av, nInt) else scalarPow(av, nFloat)) as V + } + return newTensor(outData, a.dtype, a) + } + + /** Repeated-multiply for small integer exponents. Handles n < 0 via reciprocal. */ + private fun integerPow(base: Float, n: Int): Float { + if (n == 0) return 1f + if (n < 0) return 1f / integerPow(base, -n) + var result = 1f + var b = base + var e = n + while (e > 0) { + if (e and 1 == 1) result *= b + b *= b + e = e ushr 1 + } + return result + } + + private fun scalarPow(base: Float, exp: Float): Float = + base.toDouble().pow(exp.toDouble()).toFloat() + + /** + * Element-wise natural log: `c[i] = ln(a[i])`. Negative or zero + * inputs follow `kotlin.math.ln` semantics (negative → NaN, zero + * → -Infinity). Mirror of `stablehlo.log`. + */ + override fun log(tensor: Tensor): Tensor { + require( + tensor.dtype == sk.ainet.lang.types.FP32::class || + tensor.dtype == sk.ainet.lang.types.FP16::class + ) { "log supports only FP16/FP32, got ${tensor.dtype}" } + val outData = dataFactory.init(tensor.shape, tensor.dtype) { idx -> + val v = tensor.data.get(*idx) as Float + @Suppress("UNCHECKED_CAST") + ln(v) as V + } + return newTensor(outData, tensor.dtype, tensor) + } + + /** Element-wise base-2 log: `c[i] = log2(a[i])`. */ + override fun log2(tensor: Tensor): Tensor { + require( + tensor.dtype == sk.ainet.lang.types.FP32::class || + tensor.dtype == sk.ainet.lang.types.FP16::class + ) { "log2 supports only FP16/FP32, got ${tensor.dtype}" } + val outData = dataFactory.init(tensor.shape, tensor.dtype) { idx -> + val v = tensor.data.get(*idx) as Float + @Suppress("UNCHECKED_CAST") + kmLog2(v) as V + } + return newTensor(outData, tensor.dtype, tensor) + } + + /** Element-wise base-10 log: `c[i] = log10(a[i])`. */ + override fun log10(tensor: Tensor): Tensor { + require( + tensor.dtype == sk.ainet.lang.types.FP32::class || + tensor.dtype == sk.ainet.lang.types.FP16::class + ) { "log10 supports only FP16/FP32, got ${tensor.dtype}" } + val outData = dataFactory.init(tensor.shape, tensor.dtype) { idx -> + val v = tensor.data.get(*idx) as Float + @Suppress("UNCHECKED_CAST") + kmLog10(v) as V + } + return newTensor(outData, tensor.dtype, tensor) + } + // ---- TinyFoA ops: abs, sign, clamp, lt, ge ---- @TensorOp() diff --git a/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsLogTest.kt b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsLogTest.kt new file mode 100644 index 00000000..ccc1f778 --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsLogTest.kt @@ -0,0 +1,98 @@ +package sk.ainet.exec.tensor.ops + +import kotlin.math.abs +import kotlin.math.ln +import kotlin.math.log10 as kmLog10 +import kotlin.math.log2 as kmLog2 +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.VoidOpsTensor +import sk.ainet.lang.tensor.data.DenseTensorDataFactory +import sk.ainet.lang.tensor.data.FloatArrayTensorData +import sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.Int32 + +/** + * Forward-parity tests for the new `log`, `log2`, `log10` ops (Tier B + * of #617). Verifies against `kotlin.math.ln/log2/log10` per element, + * plus the dtype-restriction guard. + */ +class DefaultCpuOpsLogTest { + private val dataFactory = DenseTensorDataFactory() + private val ops = DefaultCpuOps(dataFactory) + + private fun floatTensor(shape: Shape, values: FloatArray) = + VoidOpsTensor(dataFactory.fromFloatArray(shape, FP32::class, values), FP32::class) + + private fun assertCloseTo(expected: FloatArray, actual: FloatArray, tol: Float = 1e-5f) { + assertEquals(expected.size, actual.size, "length mismatch") + for (i in expected.indices) { + val diff = abs(expected[i] - actual[i]) + assertTrue(diff <= tol, "[$i] expected=${expected[i]} actual=${actual[i]} diff=$diff tol=$tol") + } + } + + @Test + fun log_matches_kotlin_math_ln() { + val a = floatTensor(Shape(5), floatArrayOf(1f, 2f, kotlin.math.E.toFloat(), 10f, 100f)) + val expected = floatArrayOf(0f, ln(2f), 1f, ln(10f), ln(100f)) + val out = ops.log(a) + assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer) + } + + @Test + fun log2_matches_kotlin_math_log2() { + val a = floatTensor(Shape(5), floatArrayOf(1f, 2f, 4f, 8f, 1024f)) + val expected = floatArrayOf(0f, 1f, 2f, 3f, 10f) + val out = ops.log2(a) + assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer) + } + + @Test + fun log10_matches_kotlin_math_log10() { + val a = floatTensor(Shape(4), floatArrayOf(1f, 10f, 100f, 1000f)) + val expected = floatArrayOf(0f, 1f, 2f, 3f) + val out = ops.log10(a) + assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer) + } + + @Test + fun log_of_negative_returns_nan() { + val a = floatTensor(Shape(2), floatArrayOf(-1f, -2f)) + val out = ops.log(a) + for (v in (out.data as FloatArrayTensorData<*>).buffer) { + assertTrue(v.isNaN(), "log of negative must be NaN, got $v") + } + } + + @Test + fun log_of_zero_returns_negative_infinity() { + val a = floatTensor(Shape(1), floatArrayOf(0f)) + val out = ops.log(a) + val result = (out.data as FloatArrayTensorData<*>).buffer[0] + assertEquals(Float.NEGATIVE_INFINITY, result, "log(0) must be -Inf, got $result") + } + + @Test + fun log_log2_log10_consistent_with_each_other() { + // log_b(x) = ln(x) / ln(b) — verify the three flavours agree. + val a = floatTensor(Shape(3), floatArrayOf(2f, 10f, 100f)) + val logVals = (ops.log(a).data as FloatArrayTensorData<*>).buffer + val log2Vals = (ops.log2(a).data as FloatArrayTensorData<*>).buffer + val log10Vals = (ops.log10(a).data as FloatArrayTensorData<*>).buffer + for (i in 0..2) { + assertEquals(log2Vals[i], logVals[i] / ln(2f), 1e-5f, "log2 consistency at $i") + assertEquals(log10Vals[i], logVals[i] / ln(10f), 1e-5f, "log10 consistency at $i") + } + } + + @Test + fun log_rejects_non_float_dtype() { + val intData = dataFactory.fromIntArray(Shape(2), Int32::class, intArrayOf(1, 2)) + val tInt = VoidOpsTensor(intData, Int32::class) + assertFailsWith { ops.log(tInt) } + } +} diff --git a/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsPowTest.kt b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsPowTest.kt new file mode 100644 index 00000000..e19dfea2 --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/jvmTest/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsPowTest.kt @@ -0,0 +1,89 @@ +package sk.ainet.exec.tensor.ops + +import kotlin.math.abs +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.VoidOpsTensor +import sk.ainet.lang.tensor.data.DenseTensorDataFactory +import sk.ainet.lang.tensor.data.FloatArrayTensorData +import sk.ainet.lang.types.FP32 + +/** + * Forward-parity tests for the new `pow` and `powScalar` ops (Tier A + * of #617). Checks both the binary form (tensor exponent) and the + * scalar form for integer + real exponents. + */ +class DefaultCpuOpsPowTest { + private val dataFactory = DenseTensorDataFactory() + private val ops = DefaultCpuOps(dataFactory) + + private fun floatTensor(shape: Shape, values: FloatArray) = + VoidOpsTensor(dataFactory.fromFloatArray(shape, FP32::class, values), FP32::class) + + private fun assertCloseTo(expected: FloatArray, actual: FloatArray, tol: Float = 1e-4f) { + assertEquals(expected.size, actual.size, "length mismatch") + for (i in expected.indices) { + val diff = abs(expected[i] - actual[i]) + assertTrue(diff <= tol, "[$i] expected=${expected[i]} actual=${actual[i]} diff=$diff tol=$tol") + } + } + + @Test + fun powScalar_integer_2_matches_x_times_x() { + val a = floatTensor(Shape(5), floatArrayOf(0.5f, 1f, 2f, 3f, -2f)) + val expected = floatArrayOf(0.25f, 1f, 4f, 9f, 4f) + val out = ops.powScalar(a, 2) + assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer) + } + + @Test + fun powScalar_integer_3_matches_x_cubed() { + val a = floatTensor(Shape(4), floatArrayOf(1f, 2f, 3f, -2f)) + val expected = floatArrayOf(1f, 8f, 27f, -8f) + val out = ops.powScalar(a, 3) + assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer) + } + + @Test + fun powScalar_negative_integer_minus_1_is_reciprocal() { + val a = floatTensor(Shape(3), floatArrayOf(2f, 4f, 0.5f)) + val expected = floatArrayOf(0.5f, 0.25f, 2f) + val out = ops.powScalar(a, -1) + assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer) + } + + @Test + fun powScalar_real_half_is_sqrt() { + val a = floatTensor(Shape(4), floatArrayOf(0f, 1f, 4f, 9f)) + val expected = floatArrayOf(0f, 1f, 2f, 3f) + val out = ops.powScalar(a, 0.5f) + assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer) + } + + @Test + fun powScalar_real_1_5_matches_kotlin_math_pow() { + val a = floatTensor(Shape(3), floatArrayOf(1f, 2f, 4f)) + val expected = floatArrayOf(1f, 2.828427f, 8f) + val out = ops.powScalar(a, 1.5f) + assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer) + } + + @Test + fun pow_binary_element_wise() { + val a = floatTensor(Shape(4), floatArrayOf(2f, 3f, 4f, 5f)) + val b = floatTensor(Shape(4), floatArrayOf(2f, 3f, 0.5f, 1f)) + val expected = floatArrayOf(4f, 27f, 2f, 5f) + val out = ops.pow(a, b) + assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer) + } + + @Test + fun pow_binary_rejects_shape_mismatch() { + val a = floatTensor(Shape(3), floatArrayOf(1f, 2f, 3f)) + val b = floatTensor(Shape(4), floatArrayOf(1f, 2f, 3f, 4f)) + assertFailsWith { ops.pow(a, b) } + } +} diff --git a/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt b/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt index c201d513..02943357 100644 --- a/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt +++ b/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt @@ -184,6 +184,21 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor return out } + // --- Power ops --- + override fun pow(a: Tensor, b: Tensor): Tensor { + val out = base.pow(a, b) + record(PowOperation(), listOf(a, b), listOf(out)) + return out + } + + override fun powScalar(a: Tensor, n: Number): Tensor { + val out = base.powScalar(a, n) + // Single-input + scalar exponent stashed in parameters so the + // backward formula can recover it (a-partial is n * a^(n-1)). + record(PowOperation(parameters = mapOf("scalar_exponent" to n)), listOf(a), listOf(out)) + return out + } + // --- Scalar ops --- override fun addScalar(a: Tensor, b: Number): Tensor { val out = base.addScalar(a, b) @@ -426,6 +441,9 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor override fun mean(tensor: Tensor, dim: Int?): Tensor = base.mean(tensor, dim) override fun variance(tensor: Tensor, dim: Int?): Tensor = base.variance(tensor, dim) override fun sqrt(tensor: Tensor): Tensor = base.sqrt(tensor) + override fun log(tensor: Tensor): Tensor = base.log(tensor) + override fun log2(tensor: Tensor): Tensor = base.log2(tensor) + override fun log10(tensor: Tensor): Tensor = base.log10(tensor) override fun abs(tensor: Tensor): Tensor = base.abs(tensor) override fun sign(tensor: Tensor): Tensor = base.sign(tensor) override fun clamp(tensor: Tensor, minVal: Float, maxVal: Float): Tensor = base.clamp(tensor, minVal, maxVal) diff --git a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt index 1e70aba1..eee0ab16 100644 --- a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt +++ b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt @@ -519,12 +519,26 @@ public class DefaultGradientTape( trace.inputs.mapNotNull { session.resolve(it) as? Tensor } } val out = outputs.firstOrNull() ?: return - + val anyInputRequiresGrad = inputs.any { it.requiresGrad } - // Propagate requiresGrad to output if any input requires it - if (anyInputRequiresGrad && !out.requiresGrad) { - (out as? Tensor)?.withRequiresGrad(true) + // Propagate requiresGrad to output if any input requires it. + // For multi-output ops (split) propagate to every chunk so the user + // can attach a loss to any of them. + if (anyInputRequiresGrad) { + outputs.forEach { o -> + if (!o.requiresGrad) (o as? Tensor)?.withRequiresGrad(true) + } + } + + // Special-case split: the standard "one backward per opTrace" shape + // doesn't fit because each chunk is its own tensor with its own + // upstream gradient. Register one backward op per chunk; each + // contributes a sparse input grad (zero everywhere except the slice + // it produced). Standard tape accumulation concats them naturally. + if (trace.opType == "split" && outputs.size > 1 && anyInputRequiresGrad) { + registerSplitBackwards(trace, inputs[0], outputs) + return } if (!out.requiresGrad) { @@ -535,6 +549,56 @@ public class DefaultGradientTape( backwardOps += backward } + private fun registerSplitBackwards( + trace: OpTrace, + input: Tensor, + outputs: List>, + ) { + val splitSize = (trace.attributes["splitSize"] as? Number)?.toInt() ?: return + val dim = (trace.attributes["dim"] as? Number)?.toInt() ?: 0 + outputs.forEachIndexed { chunkIndex, chunkOut -> + val offset = chunkIndex * splitSize + backwardOps += BackwardOp(listOf(input), chunkOut) { upstream -> + val grad = zerosLike(input) + scatterAlongDim(grad, upstream, dim, offset) + listOf?>(grad) + } + } + } + + /** + * Copy every element of [src] into [dest] at position `[offset, offset + src.shape[dim])` + * along [dim], leaving the rest of [dest] untouched. Used by the split + * backward to scatter each chunk's upstream gradient back into the right + * region of the input gradient. + */ + private fun scatterAlongDim( + dest: Tensor, + src: Tensor, + dim: Int, + offset: Int, + ) { + val destDims = dest.shape.dimensions + val srcDims = src.shape.dimensions + val rank = destDims.size + val destIdx = IntArray(rank) + val srcIdx = IntArray(rank) + fun walk(d: Int) { + if (d == rank) { + @Suppress("UNCHECKED_CAST") + dest.data.set(*destIdx, value = src.data.get(*srcIdx) as Any) + return + } + val len = if (d == dim) srcDims[dim] else destDims[d] + for (i in 0 until len) { + srcIdx[d] = i + destIdx[d] = if (d == dim) i + offset else i + walk(d + 1) + } + } + walk(0) + } + override fun addBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> = listOf(matchShape(upstream, inputs[0]), matchShape(upstream, inputs[1])) @@ -639,33 +703,120 @@ public class DefaultGradientTape( } override fun conv1dBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { - // Conv1d backward is complex, return null for now - return listOf(null, null, null) + val input = inputs[0] + val weight = inputs[1] + val bias = inputs.getOrNull(2) + val stride = (attributes["stride"] as? Number)?.toInt() ?: 1 + val padding = (attributes["padding"] as? Number)?.toInt() ?: 0 + val dilation = (attributes["dilation"] as? Number)?.toInt() ?: 1 + val groups = (attributes["groups"] as? Number)?.toInt() ?: 1 + return conv1dGrads(upstream, input, weight, bias, stride, padding, dilation, groups) + } + + override fun powBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // c = a^b + // ∂c/∂a = b * a^(b-1) * upstream + // ∂c/∂b = a^b * log(a) * upstream (note: log(a) is undefined for a <= 0) + val a = inputs[0] + val b = inputs[1] + val ops = a.ops + // ∂c/∂a = b * a^(b-1) * upstream + // Compute a^(b-1) via a^b / a = output / a (cheaper, reuses cached output). + val aPowBMinus1 = ops.divide(output, a) + val dA = ops.multiply(upstream, ops.multiply(b, aPowBMinus1)) + // ∂c/∂b = output * log(a) * upstream + val logA = ops.log(a) + val dB = ops.multiply(upstream, ops.multiply(output, logA)) + return listOf(dA, dB) + } + + override fun powScalarBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // c = a^n (n is the scalar exponent — stashed by KSP as "n" String, + // by RecordingTensorOpsDecorator as "scalar_exponent" Number). + // ∂c/∂a = n * a^(n-1) * upstream + // n isn't differentiable — single-input op, single-output gradient. + val a = inputs[0] + val nRaw = attributes["n"] ?: attributes["scalar_exponent"] + ?: error("powScalarBackward requires attributes['n'] or ['scalar_exponent']; got attrs=$attributes") + val n = when (nRaw) { + is Number -> nRaw.toFloat() + is String -> nRaw.toFloat() + else -> error("powScalarBackward: unexpected exponent type ${nRaw::class}") + } + val ops = a.ops + // a^(n-1) — compute directly (cheaper than output / a which has a 0-divide + // hazard when a contains zeros and n > 0). + val aPowNMinus1 = ops.powScalar(a, n - 1f) + val dA = ops.mulScalar(ops.multiply(upstream, aPowNMinus1), n) + return listOf(dA) + } + + override fun logBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // ∂log(a)/∂a = 1/a, so da = upstream / a. + val a = inputs[0] + return listOf(a.ops.divide(upstream, a)) + } + + override fun log2Backward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // ∂log2(a)/∂a = 1/(a · ln 2), so da = upstream / (a · ln 2). + val a = inputs[0] + val ops = a.ops + val gradAOverA = ops.divide(upstream, a) + return listOf(ops.divScalar(gradAOverA, kotlin.math.ln(2.0))) + } + + override fun log10Backward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // ∂log10(a)/∂a = 1/(a · ln 10). + val a = inputs[0] + val ops = a.ops + val gradAOverA = ops.divide(upstream, a) + return listOf(ops.divScalar(gradAOverA, kotlin.math.ln(10.0))) } override fun conv2dBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { - // d(conv2d(x, w, b))/dx, d(conv2d(x, w, b))/dw, d(conv2d(x, w, b))/db - // This is complex and usually implemented in the backend. - // For now we return null to signal it's not implemented yet, or throw if we want to be strict. - return listOf(null, null, null) + val input = inputs[0] + val weight = inputs[1] + val bias = inputs.getOrNull(2) + val stride = pair2(attributes["stride"], 1) + val padding = pair2(attributes["padding"], 0) + val dilation = pair2(attributes["dilation"], 1) + val groups = (attributes["groups"] as? Number)?.toInt() ?: 1 + return conv2dGrads(upstream, input, weight, bias, stride, padding, dilation, groups) } override fun conv3dBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { - // Conv3d backward is complex, return null for now - return listOf(null, null, null) + val input = inputs[0] + val weight = inputs[1] + val bias = inputs.getOrNull(2) + val stride = triple3(attributes["stride"], 1) + val padding = triple3(attributes["padding"], 0) + val dilation = triple3(attributes["dilation"], 1) + val groups = (attributes["groups"] as? Number)?.toInt() ?: 1 + return conv3dGrads(upstream, input, weight, bias, stride, padding, dilation, groups) } override fun maxPool2dBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { - return listOf(null) + val input = inputs[0] + val kernel = pair2(attributes["kernelSize"], 1) + val stride = pair2(attributes["stride"], 1) + val padding = pair2(attributes["padding"], 0) + return listOf(maxPool2dGrad(upstream, input, kernel, stride, padding)) } override fun avgPool2dBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { - // AvgPool2d backward is complex, return null for now - return listOf(null) + val input = inputs[0] + val kernel = pair2(attributes["kernelSize"], 1) + val stride = pair2(attributes["stride"], 1) + val padding = pair2(attributes["padding"], 0) + val countIncludePad = (attributes["countIncludePad"] as? Boolean) ?: true + return listOf(avgPool2dGrad(upstream, input, kernel, stride, padding, countIncludePad)) } override fun upsample2dBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { - return listOf(null) + val input = inputs[0] + val scale = pair2(attributes["scale"], 1) + val mode = (attributes["mode"] as? String) ?: "Nearest" + return listOf(upsample2dGrad(upstream, input, scale, mode)) } override fun leakyReluBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { @@ -714,9 +865,10 @@ public class DefaultGradientTape( } override fun splitBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { - // splitBackward: d(split(x))/dx = concat(upstreams) - // Since each output of split is recorded separately, we need to accumulate them. - // This is not easily handled in the current tape. + // Unused: split's per-chunk backwards are registered directly by + // recordTrace via registerSplitBackwards, because a single + // BackwardOp(output=...) can't carry N upstream gradients. + // Kept here only to satisfy the DifferentiableTensorOps interface. return listOf(null) } override fun squeezeBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> = listOf(upstream.ops.unsqueeze(upstream, (attributes["dim"] as? Int) ?: 0)) // simplistic @@ -892,6 +1044,11 @@ public class DefaultGradientTape( "gelu" -> BackwardOp(inputs, output) { upstream -> geluBackward(upstream, output, inputs, trace.attributes) } "variance" -> BackwardOp(inputs, output) { upstream -> varianceBackward(upstream, output, inputs, trace.attributes) } "sqrt" -> BackwardOp(inputs, output) { upstream -> sqrtBackward(upstream, output, inputs, trace.attributes) } + "pow" -> BackwardOp(inputs, output) { upstream -> powBackward(upstream, output, inputs, trace.attributes) } + "powScalar" -> BackwardOp(inputs, output) { upstream -> powScalarBackward(upstream, output, inputs, trace.attributes) } + "log" -> BackwardOp(inputs, output) { upstream -> logBackward(upstream, output, inputs, trace.attributes) } + "log2" -> BackwardOp(inputs, output) { upstream -> log2Backward(upstream, output, inputs, trace.attributes) } + "log10" -> BackwardOp(inputs, output) { upstream -> log10Backward(upstream, output, inputs, trace.attributes) } "abs" -> BackwardOp(inputs, output) { upstream -> absBackward(upstream, output, inputs, trace.attributes) } "clamp" -> BackwardOp(inputs, output) { upstream -> clampBackward(upstream, output, inputs, trace.attributes) } "narrow" -> BackwardOp(inputs, output) { upstream -> narrowBackward(upstream, output, inputs, trace.attributes) } @@ -900,6 +1057,7 @@ public class DefaultGradientTape( "conv2d" -> BackwardOp(inputs, output) { upstream -> conv2dBackward(upstream, output, inputs, trace.attributes) } "conv3d" -> BackwardOp(inputs, output) { upstream -> conv3dBackward(upstream, output, inputs, trace.attributes) } "maxPool2d" -> BackwardOp(inputs, output) { upstream -> maxPool2dBackward(upstream, output, inputs, trace.attributes) } + "avgPool2d" -> BackwardOp(inputs, output) { upstream -> avgPool2dBackward(upstream, output, inputs, trace.attributes) } "upsample2d" -> BackwardOp(inputs, output) { upstream -> upsample2dBackward(upstream, output, inputs, trace.attributes) } "concat" -> BackwardOp(inputs, output) { upstream -> concatBackward(upstream, output, inputs, trace.attributes) } "split" -> BackwardOp(inputs, output) { upstream -> splitBackward(upstream, output, inputs, trace.attributes) } @@ -1217,6 +1375,425 @@ public class DefaultGradientTape( return gradOut } + /** + * Coerce a Pair-shaped attribute (KSP records pairs as `List` of size 2, + * the decorator may record them as `Pair`) to a `Pair`, + * falling back to (default, default). + */ + private fun pair2(raw: Any?, default: Int): Pair = when (raw) { + is List<*> -> { + val a = (raw.getOrNull(0) as? Number)?.toInt() ?: default + val b = (raw.getOrNull(1) as? Number)?.toInt() ?: default + a to b + } + is Pair<*, *> -> { + val a = (raw.first as? Number)?.toInt() ?: default + val b = (raw.second as? Number)?.toInt() ?: default + a to b + } + else -> default to default + } + + private fun triple3(raw: Any?, default: Int): Triple = when (raw) { + is List<*> -> Triple( + (raw.getOrNull(0) as? Number)?.toInt() ?: default, + (raw.getOrNull(1) as? Number)?.toInt() ?: default, + (raw.getOrNull(2) as? Number)?.toInt() ?: default, + ) + is Triple<*, *, *> -> Triple( + (raw.first as? Number)?.toInt() ?: default, + (raw.second as? Number)?.toInt() ?: default, + (raw.third as? Number)?.toInt() ?: default, + ) + else -> Triple(default, default, default) + } + + /** + * Direct CPU loops for conv2d backward. Correctness-first first-cut; perf + * follow-up tracked separately. Reuses the forward windowing formula + * (`ih = oh*sH - pH + kh*dH`) — the closed-form derivatives are: + * dInput[b, ic, ih, iw] += upstream[b, oc, oh, ow] * weight[oc, kc, kh, kw] + * dWeight[oc, kc, kh, kw] += upstream[b, oc, oh, ow] * input[b, ic, ih, iw] + * dBias[oc] += upstream[b, oc, oh, ow] + */ + private fun conv2dGrads( + upstream: Tensor, + input: Tensor, + weight: Tensor, + bias: Tensor?, + stride: Pair, + padding: Pair, + dilation: Pair, + groups: Int, + ): List?> { + val n = input.shape[0] + val cIn = input.shape[1] + val inH = input.shape[2] + val inW = input.shape[3] + val cOut = weight.shape[0] + val cInPerGroup = weight.shape[1] + val kH = weight.shape[2] + val kW = weight.shape[3] + val outH = upstream.shape[2] + val outW = upstream.shape[3] + val (sH, sW) = stride + val (pH, pW) = padding + val (dH, dW) = dilation + + val dInput = zerosLike(input) + val dWeight = zerosLike(weight) + val dBias = bias?.let { zerosLike(it) } + val biasRank = bias?.rank ?: 0 + + for (b in 0 until n) { + for (oc in 0 until cOut) { + val groupIdx = (oc * groups) / cOut + val inCStart = groupIdx * cInPerGroup + for (oh in 0 until outH) { + val hBase = oh * sH - pH + for (ow in 0 until outW) { + val wBase = ow * sW - pW + val gOut = (upstream.data.get(b, oc, oh, ow) as Number).toFloat() + if (dBias != null) { + val cur = when (biasRank) { + 1 -> (dBias.data.get(oc) as Number).toFloat() + 4 -> (dBias.data.get(0, oc, 0, 0) as Number).toFloat() + else -> 0f + } + val updated = cur + gOut + @Suppress("UNCHECKED_CAST") + when (biasRank) { + 1 -> dBias.data.set(oc, value = updated as Any) + 4 -> dBias.data.set(0, oc, 0, 0, value = updated as Any) + } + } + for (kc in 0 until cInPerGroup) { + val ic = inCStart + kc + for (kh in 0 until kH) { + val ih = hBase + kh * dH + if (ih !in 0 until inH) continue + for (kw in 0 until kW) { + val iw = wBase + kw * dW + if (iw !in 0 until inW) continue + val vIn = (input.data.get(b, ic, ih, iw) as Number).toFloat() + val vW = (weight.data.get(oc, kc, kh, kw) as Number).toFloat() + val curIn = (dInput.data.get(b, ic, ih, iw) as Number).toFloat() + @Suppress("UNCHECKED_CAST") + dInput.data.set(b, ic, ih, iw, value = (curIn + gOut * vW) as Any) + val curW = (dWeight.data.get(oc, kc, kh, kw) as Number).toFloat() + @Suppress("UNCHECKED_CAST") + dWeight.data.set(oc, kc, kh, kw, value = (curW + gOut * vIn) as Any) + } + } + } + } + } + } + } + return listOf(dInput, dWeight, dBias) + } + + /** conv1d backward — 1D analogue of conv2dGrads (input [N, C, L]). */ + private fun conv1dGrads( + upstream: Tensor, + input: Tensor, + weight: Tensor, + bias: Tensor?, + stride: Int, + padding: Int, + dilation: Int, + groups: Int, + ): List?> { + val n = input.shape[0] + val inL = input.shape[2] + val cOut = weight.shape[0] + val cInPerGroup = weight.shape[1] + val kL = weight.shape[2] + val outL = upstream.shape[2] + + val dInput = zerosLike(input) + val dWeight = zerosLike(weight) + val dBias = bias?.let { zerosLike(it) } + val biasRank = bias?.rank ?: 0 + + for (b in 0 until n) { + for (oc in 0 until cOut) { + val groupIdx = (oc * groups) / cOut + val inCStart = groupIdx * cInPerGroup + for (ol in 0 until outL) { + val lBase = ol * stride - padding + val gOut = (upstream.data.get(b, oc, ol) as Number).toFloat() + if (dBias != null) { + val cur = when (biasRank) { + 1 -> (dBias.data.get(oc) as Number).toFloat() + else -> 0f + } + @Suppress("UNCHECKED_CAST") + if (biasRank == 1) dBias.data.set(oc, value = (cur + gOut) as Any) + } + for (kc in 0 until cInPerGroup) { + val ic = inCStart + kc + for (kl in 0 until kL) { + val il = lBase + kl * dilation + if (il !in 0 until inL) continue + val vIn = (input.data.get(b, ic, il) as Number).toFloat() + val vW = (weight.data.get(oc, kc, kl) as Number).toFloat() + val curIn = (dInput.data.get(b, ic, il) as Number).toFloat() + @Suppress("UNCHECKED_CAST") + dInput.data.set(b, ic, il, value = (curIn + gOut * vW) as Any) + val curW = (dWeight.data.get(oc, kc, kl) as Number).toFloat() + @Suppress("UNCHECKED_CAST") + dWeight.data.set(oc, kc, kl, value = (curW + gOut * vIn) as Any) + } + } + } + } + } + return listOf(dInput, dWeight, dBias) + } + + /** conv3d backward — 3D analogue (input [N, C, D, H, W]). */ + private fun conv3dGrads( + upstream: Tensor, + input: Tensor, + weight: Tensor, + bias: Tensor?, + stride: Triple, + padding: Triple, + dilation: Triple, + groups: Int, + ): List?> { + val n = input.shape[0] + val inD = input.shape[2] + val inH = input.shape[3] + val inW = input.shape[4] + val cOut = weight.shape[0] + val cInPerGroup = weight.shape[1] + val kD = weight.shape[2] + val kH = weight.shape[3] + val kW = weight.shape[4] + val outD = upstream.shape[2] + val outH = upstream.shape[3] + val outW = upstream.shape[4] + val (sD, sH, sW) = stride + val (pD, pH, pW) = padding + val (eD, eH, eW) = dilation + + val dInput = zerosLike(input) + val dWeight = zerosLike(weight) + val dBias = bias?.let { zerosLike(it) } + val biasRank = bias?.rank ?: 0 + + for (b in 0 until n) { + for (oc in 0 until cOut) { + val groupIdx = (oc * groups) / cOut + val inCStart = groupIdx * cInPerGroup + for (od in 0 until outD) { + val dBaseI = od * sD - pD + for (oh in 0 until outH) { + val hBase = oh * sH - pH + for (ow in 0 until outW) { + val wBase = ow * sW - pW + val gOut = (upstream.data.get(b, oc, od, oh, ow) as Number).toFloat() + if (dBias != null && biasRank == 1) { + val cur = (dBias.data.get(oc) as Number).toFloat() + @Suppress("UNCHECKED_CAST") + dBias.data.set(oc, value = (cur + gOut) as Any) + } + for (kc in 0 until cInPerGroup) { + val ic = inCStart + kc + for (kd in 0 until kD) { + val id = dBaseI + kd * eD + if (id !in 0 until inD) continue + for (kh in 0 until kH) { + val ih = hBase + kh * eH + if (ih !in 0 until inH) continue + for (kw in 0 until kW) { + val iw = wBase + kw * eW + if (iw !in 0 until inW) continue + val vIn = (input.data.get(b, ic, id, ih, iw) as Number).toFloat() + val vW = (weight.data.get(oc, kc, kd, kh, kw) as Number).toFloat() + val curIn = (dInput.data.get(b, ic, id, ih, iw) as Number).toFloat() + @Suppress("UNCHECKED_CAST") + dInput.data.set(b, ic, id, ih, iw, value = (curIn + gOut * vW) as Any) + val curW = (dWeight.data.get(oc, kc, kd, kh, kw) as Number).toFloat() + @Suppress("UNCHECKED_CAST") + dWeight.data.set(oc, kc, kd, kh, kw, value = (curW + gOut * vIn) as Any) + } + } + } + } + } + } + } + } + } + return listOf(dInput, dWeight, dBias) + } + + /** + * maxPool2d backward — re-runs argmax over each window from the cached + * input and routes the upstream gradient to that single position. Ties + * resolved by taking the first encountered max (matches forward iteration + * order — kh outer, kw inner). + */ + private fun maxPool2dGrad( + upstream: Tensor, + input: Tensor, + kernel: Pair, + stride: Pair, + padding: Pair, + ): Tensor { + val n = input.shape[0] + val c = input.shape[1] + val inH = input.shape[2] + val inW = input.shape[3] + val (kH, kW) = kernel + val (sH, sW) = stride + val (pH, pW) = padding + val outH = upstream.shape[2] + val outW = upstream.shape[3] + val dInput = zerosLike(input) + + for (b in 0 until n) { + for (ch in 0 until c) { + for (oh in 0 until outH) { + val hBase = oh * sH - pH + for (ow in 0 until outW) { + val wBase = ow * sW - pW + var bestH = -1 + var bestW = -1 + var bestVal = Float.NEGATIVE_INFINITY + for (kh in 0 until kH) { + val ih = hBase + kh + if (ih !in 0 until inH) continue + for (kw in 0 until kW) { + val iw = wBase + kw + if (iw !in 0 until inW) continue + val v = (input.data.get(b, ch, ih, iw) as Number).toFloat() + if (v > bestVal) { + bestVal = v + bestH = ih + bestW = iw + } + } + } + if (bestH < 0) continue + val gOut = (upstream.data.get(b, ch, oh, ow) as Number).toFloat() + val cur = (dInput.data.get(b, ch, bestH, bestW) as Number).toFloat() + @Suppress("UNCHECKED_CAST") + dInput.data.set(b, ch, bestH, bestW, value = (cur + gOut) as Any) + } + } + } + } + return dInput + } + + /** + * avgPool2d backward — distribute each upstream element uniformly across + * the pooling window it came from. Divisor matches the forward rule: + * countIncludePad = true → always `kH * kW` + * countIncludePad = false → number of in-bounds positions per window + */ + private fun avgPool2dGrad( + upstream: Tensor, + input: Tensor, + kernel: Pair, + stride: Pair, + padding: Pair, + countIncludePad: Boolean, + ): Tensor { + val n = input.shape[0] + val c = input.shape[1] + val inH = input.shape[2] + val inW = input.shape[3] + val (kH, kW) = kernel + val (sH, sW) = stride + val (pH, pW) = padding + val outH = upstream.shape[2] + val outW = upstream.shape[3] + val dInput = zerosLike(input) + + for (b in 0 until n) { + for (ch in 0 until c) { + for (oh in 0 until outH) { + val hBase = oh * sH - pH + for (ow in 0 until outW) { + val wBase = ow * sW - pW + // Count valid positions when countIncludePad=false (matches forward). + var valid = 0 + for (kh in 0 until kH) { + val ih = hBase + kh + if (ih !in 0 until inH) continue + for (kw in 0 until kW) { + val iw = wBase + kw + if (iw in 0 until inW) valid++ + } + } + val divisor = if (countIncludePad) (kH * kW) else maxOf(valid, 1) + val gOut = (upstream.data.get(b, ch, oh, ow) as Number).toFloat() / divisor + for (kh in 0 until kH) { + val ih = hBase + kh + if (ih !in 0 until inH) continue + for (kw in 0 until kW) { + val iw = wBase + kw + if (iw !in 0 until inW) continue + val cur = (dInput.data.get(b, ch, ih, iw) as Number).toFloat() + @Suppress("UNCHECKED_CAST") + dInput.data.set(b, ch, ih, iw, value = (cur + gOut) as Any) + } + } + } + } + } + } + return dInput + } + + /** + * upsample2d backward (NEAREST only — the CPU forward only supports + * Nearest, so the backward mirrors that). For each input position, sum + * the upstream gradients of every output position it produced (the + * scaleH × scaleW block above-left of [ih*scaleH, iw*scaleW]). + */ + private fun upsample2dGrad( + upstream: Tensor, + input: Tensor, + scale: Pair, + mode: String, + ): Tensor { + require(mode.equals("Nearest", ignoreCase = true)) { + "upsample2dBackward: only Nearest mode implemented (got mode=$mode)" + } + val n = input.shape[0] + val c = input.shape[1] + val inH = input.shape[2] + val inW = input.shape[3] + val (scaleH, scaleW) = scale + val outH = upstream.shape[2] + val outW = upstream.shape[3] + val dInput = zerosLike(input) + + for (b in 0 until n) { + for (ch in 0 until c) { + for (oh in 0 until outH) { + val ih = oh / scaleH + if (ih !in 0 until inH) continue + for (ow in 0 until outW) { + val iw = ow / scaleW + if (iw !in 0 until inW) continue + val gOut = (upstream.data.get(b, ch, oh, ow) as Number).toFloat() + val cur = (dInput.data.get(b, ch, ih, iw) as Number).toFloat() + @Suppress("UNCHECKED_CAST") + dInput.data.set(b, ch, ih, iw, value = (cur + gOut) as Any) + } + } + } + } + return dInput + } + private fun clampGrad(upstream: Tensor, input: Tensor, minVal: Float, maxVal: Float): Tensor { val matchedUpstream = matchShape(upstream, input) val gradOut = zerosLike(input) diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/ComputeGraphExecutorTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/ComputeGraphExecutorTest.kt index 63406465..124dcdba 100644 --- a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/ComputeGraphExecutorTest.kt +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/ComputeGraphExecutorTest.kt @@ -177,6 +177,11 @@ private class TestTensorOps : TensorOps { override fun mean(tensor: Tensor, dim: Int?): Tensor = tensor override fun variance(tensor: Tensor, dim: Int?): Tensor = tensor override fun sqrt(tensor: Tensor): Tensor = tensor + override fun pow(a: Tensor, b: Tensor): Tensor = a + override fun powScalar(a: Tensor, n: Number): Tensor = a + override fun log(tensor: Tensor): Tensor = tensor + override fun log2(tensor: Tensor): Tensor = tensor + override fun log10(tensor: Tensor): Tensor = tensor override fun abs(tensor: Tensor): Tensor = tensor override fun sign(tensor: Tensor): Tensor = tensor override fun clamp(tensor: Tensor, minVal: Float, maxVal: Float): Tensor = tensor diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/CnnTrainingStepTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/CnnTrainingStepTest.kt new file mode 100644 index 00000000..b0240bf5 --- /dev/null +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/CnnTrainingStepTest.kt @@ -0,0 +1,145 @@ +package sk.ainet.exec.autograd + +import kotlin.test.Test +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import sk.ainet.context.Phase +import sk.ainet.exec.tensor.ops.DefaultCpuOps +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.DefaultGradientTape +import sk.ainet.lang.graph.DefaultGraphExecutionContext +import sk.ainet.lang.nn.optim.sgd +import sk.ainet.lang.nn.topology.ModuleParameter +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.matmul +import sk.ainet.lang.tensor.relu +import sk.ainet.lang.tensor.withRequiresGrad +import sk.ainet.lang.trace.GraphSink +import sk.ainet.lang.types.FP32 + +/** + * End-to-end Tier D check: one SGD step on a tiny CNN exercises every + * conv/pool/upsample backward formula end-to-end. Confirms that: + * 1. The forward and backward graphs compose without dropping any op. + * 2. Every trainable parameter (conv weight, conv bias, linear weight, + * linear bias) receives a non-null gradient. + * 3. Loss decreases — or at least doesn't increase — after the optimiser + * applies the SGD update. + * + * Architecture (deliberately tiny for the in-process tape backend): + * input [1, 1, 4, 4] + * → conv2d(weight [2, 1, 2, 2], bias [2], stride 1, padding 0) → [1, 2, 3, 3] + * → relu → [1, 2, 3, 3] + * → maxPool2d(2, stride 1, pad 0) → [1, 2, 2, 2] + * → reshape [1, 8] + * → matmul(linW [8, 3]) + linB [1, 3] → [1, 3] + * → mse vs target [1, 3] + */ +class CnnTrainingStepTest { + + private fun trainCtx(): DefaultGraphExecutionContext { + val dataFactory = sk.ainet.lang.tensor.data.DenseTensorDataFactory() + val cpuOps = DefaultCpuOps(dataFactory) + val graph = DefaultComputeGraph() + return DefaultGraphExecutionContext( + baseOps = cpuOps, + phase = Phase.TRAIN, + tensorDataFactory = dataFactory, + createTapeFactory = { _ -> DefaultGradientTape(true) }, + computeGraph = graph, + baseSink = GraphSink(graph), + ) + } + + @Test + fun cnn_one_sgd_step_decreases_loss_and_populates_all_grads() { + val ctx = trainCtx() + + // Fixed input + target so the test is deterministic. + val input = ctx.fromFloatArray( + Shape(1, 1, 4, 4), FP32::class, + floatArrayOf( + 0.2f, 0.5f, -0.3f, 0.8f, + -0.4f, 0.1f, 0.6f, -0.7f, + 0.9f, -0.2f, 0.3f, 0.4f, + -0.1f, 0.7f, -0.5f, 0.6f, + ), + ) + val target = ctx.fromFloatArray( + Shape(1, 3), FP32::class, + floatArrayOf(1f, 0f, -1f), + ) + + // Trainable parameters. Values handpicked so the network actually + // computes something nontrivial (no all-zeros after ReLU). + val convW = ctx.fromFloatArray( + Shape(2, 1, 2, 2), FP32::class, + floatArrayOf( + 0.3f, -0.4f, 0.5f, 0.1f, // out-channel 0 + -0.2f, 0.6f, 0.1f, 0.3f, // out-channel 1 + ), + ).withRequiresGrad() + val convB = ctx.fromFloatArray( + Shape(2), FP32::class, floatArrayOf(0.05f, -0.05f), + ).withRequiresGrad() + val linW = ctx.fromFloatArray( + Shape(8, 3), FP32::class, + FloatArray(24) { (it % 7 - 3) * 0.1f }, + ).withRequiresGrad() + val linB = ctx.fromFloatArray( + Shape(1, 3), FP32::class, floatArrayOf(0.0f, 0.0f, 0.0f), + ).withRequiresGrad() + + val convWParam = ModuleParameter.WeightParameter("convW", convW) + val convBParam = ModuleParameter.BiasParameter("convB", convB) + val linWParam = ModuleParameter.WeightParameter("linW", linW) + val linBParam = ModuleParameter.BiasParameter("linB", linB) + + // Forward + record + backward in one block. + fun forward(): sk.ainet.lang.tensor.Tensor { + val conv = input.ops.conv2d( + input, convW, convB, + stride = 1 to 1, padding = 0 to 0, dilation = 1 to 1, groups = 1, + ) + val activated = conv.relu() + val pooled = activated.ops.maxPool2d( + activated, kernelSize = 2 to 2, stride = 1 to 1, padding = 0 to 0, + ) + val flat = pooled.ops.reshape(pooled, Shape(1, 8)) + val logits = flat.matmul(linW).ops.add(flat.matmul(linW), linB) + val diff = logits.ops.subtract(logits, target) + return logits.ops.sum(logits.ops.multiply(diff, diff)) + } + + // Baseline loss (eager — no need to record). + val initialLoss = forward().data.get() + + // Training step: record forward, populate gradients, step optimiser. + val pair = ctx.record { forward() } + val tape = pair.first as DefaultGradientTape + val loss = pair.second + tape.computeGradients( + targets = listOf(loss), + sources = listOf(convW, convB, linW, linB), + ) + + assertNotNull(convW.grad, "convW must have grad after backward") + assertNotNull(convB.grad, "convB must have grad after backward") + assertNotNull(linW.grad, "linW must have grad after backward") + assertNotNull(linB.grad, "linB must have grad after backward") + + val optimizer = sgd(lr = 0.01) + optimizer.addParameter(convWParam) + optimizer.addParameter(convBParam) + optimizer.addParameter(linWParam) + optimizer.addParameter(linBParam) + optimizer.step() + optimizer.zeroGrad() + + val updatedLoss = forward().data.get() + assertTrue( + updatedLoss <= initialLoss, + "loss should not increase after SGD step (initial=$initialLoss, after=$updatedLoss)", + ) + } +} diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/ConvPoolBackwardTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/ConvPoolBackwardTest.kt new file mode 100644 index 00000000..f65e73cd --- /dev/null +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/ConvPoolBackwardTest.kt @@ -0,0 +1,225 @@ +package sk.ainet.exec.autograd + +import kotlin.math.abs +import kotlin.test.Test +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import sk.ainet.context.Phase +import sk.ainet.exec.tensor.ops.DefaultCpuOps +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.DefaultGradientTape +import sk.ainet.lang.graph.DefaultGraphExecutionContext +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.data.DenseTensorDataFactory +import sk.ainet.lang.tensor.data.FloatArrayTensorData +import sk.ainet.lang.tensor.ops.UpsampleMode +import sk.ainet.lang.tensor.withRequiresGrad +import sk.ainet.lang.trace.GraphSink +import sk.ainet.lang.types.FP32 + +/** + * Tier C parity tests for conv1d/2d/3d and pool/upsample backward formulas. + * Compares analytic gradient vs central finite difference on a small, fixed + * tensor — small enough to keep the O(window*kernel) brute-force backward + * cheap. FP32 tolerance is generous (3e-2) because conv accumulates many + * products and FP32 rounding adds up. + */ +class ConvPoolBackwardTest { + + private fun ctx(): DefaultGraphExecutionContext { + val dataFactory = DenseTensorDataFactory() + val cpuOps = DefaultCpuOps(dataFactory) + val graph = DefaultComputeGraph() + return DefaultGraphExecutionContext( + baseOps = cpuOps, + phase = Phase.TRAIN, + tensorDataFactory = dataFactory, + createTapeFactory = { _ -> DefaultGradientTape(true) }, + computeGraph = graph, + baseSink = GraphSink(graph), + ) + } + + private fun floatTensor(c: DefaultGraphExecutionContext, shape: Shape, values: FloatArray): Tensor = + c.fromFloatArray(shape, FP32::class, values) + + private fun buf(t: Tensor<*, *>): FloatArray = (t.data as FloatArrayTensorData<*>).buffer + + /** + * Compares analytic gradient w.r.t. `x` against central finite-difference + * for `f`. `f` builds a scalar (sum-reduced) output from `x` plus any + * captured constants. Output is sum-reduced inside the recording scope + * so the gradient corresponds to df/dx element-wise. + */ + private fun assertGradMatchesFiniteDiff( + xShape: Shape, + x0: FloatArray, + eps: Float = 1e-3f, + tol: Float = 3e-2f, + f: (DefaultGraphExecutionContext, Tensor) -> Tensor, + ) { + val ctx = ctx() + val x = floatTensor(ctx, xShape, x0.copyOf()).withRequiresGrad() + val pair = ctx.record { + val out = f(this, x) + out.ops.sum(out) + } + val sumOutput = pair.second + val tape = pair.first as DefaultGradientTape + tape.computeGradients(targets = listOf(sumOutput), sources = listOf(x)) + val analyticGrad = x.grad + assertNotNull(analyticGrad, "tape should populate x.grad") + val analytic = buf(analyticGrad) + + for (i in x0.indices) { + val xPlus = x0.copyOf().also { it[i] += eps } + val xMinus = x0.copyOf().also { it[i] -= eps } + val ctxPlus = ctx() + val ctxMinus = ctx() + val fPlus = buf(f(ctxPlus, floatTensor(ctxPlus, xShape, xPlus))).sumElems() + val fMinus = buf(f(ctxMinus, floatTensor(ctxMinus, xShape, xMinus))).sumElems() + val fdGrad = (fPlus - fMinus) / (2 * eps) + val diff = abs(analytic[i] - fdGrad) + assertTrue( + diff <= tol, + "[$i] analytic=${analytic[i]} fd=$fdGrad diff=$diff tol=$tol", + ) + } + } + + private fun FloatArray.sumElems(): Float { + var s = 0f + for (v in this) s += v + return s + } + + @Test + fun conv2d_backward_input_matches_finite_diff() { + // Input [1, 1, 4, 4], weight [2, 1, 2, 2], stride 1, padding 0. + val w = floatArrayOf( + 1f, 0f, 0f, -1f, // out-channel 0 + 0.5f, -0.5f, 1f, 1f, // out-channel 1 + ) + val bShape = Shape(2) + val bias = floatArrayOf(0.1f, -0.2f) + assertGradMatchesFiniteDiff( + xShape = Shape(1, 1, 4, 4), + x0 = floatArrayOf( + 0.5f, 1f, -0.3f, 0.8f, + -1f, 0.2f, 0.7f, -0.4f, + 0.1f, -0.9f, 0.6f, 0.3f, + 0.4f, 0.5f, -0.2f, 1f, + ), + ) { c, x -> + val wT = floatTensor(c, Shape(2, 1, 2, 2), w) + val bT = floatTensor(c, bShape, bias) + x.ops.conv2d(x, wT, bT, stride = 1 to 1, padding = 0 to 0, dilation = 1 to 1, groups = 1) + } + } + + @Test + fun conv2d_backward_input_with_stride_and_padding() { + val w = floatArrayOf(0.3f, -0.1f, 0.8f, 0.2f) + assertGradMatchesFiniteDiff( + xShape = Shape(1, 1, 5, 5), + x0 = FloatArray(25) { (it % 7 - 3).toFloat() * 0.2f }, + ) { c, x -> + val wT = floatTensor(c, Shape(1, 1, 2, 2), w) + x.ops.conv2d(x, wT, null, stride = 2 to 2, padding = 1 to 1, dilation = 1 to 1, groups = 1) + } + } + + @Test + fun conv1d_backward_input_matches_finite_diff() { + val w = floatArrayOf(0.5f, -1f, 0.2f, 1f, 0.3f, -0.4f) + assertGradMatchesFiniteDiff( + xShape = Shape(1, 2, 6), + x0 = FloatArray(12) { (it - 6) * 0.15f }, + ) { c, x -> + // weight [C_out=1, C_in=2, kL=3] + val wT = floatTensor(c, Shape(1, 2, 3), w) + x.ops.conv1d(x, wT, null, stride = 1, padding = 0, dilation = 1, groups = 1) + } + } + + @Test + fun conv3d_backward_input_matches_finite_diff() { + val w = floatArrayOf( + 0.5f, -0.3f, 1f, 0.2f, + -1f, 0.4f, 0.1f, 0.7f, + ) + assertGradMatchesFiniteDiff( + xShape = Shape(1, 1, 3, 3, 3), + x0 = FloatArray(27) { (it % 5 - 2) * 0.1f }, + ) { c, x -> + val wT = floatTensor(c, Shape(1, 1, 2, 2, 2), w) + x.ops.conv3d( + x, wT, null, + stride = Triple(1, 1, 1), + padding = Triple(0, 0, 0), + dilation = Triple(1, 1, 1), + groups = 1, + ) + } + } + + @Test + fun maxPool2d_backward_routes_to_argmax() { + // Distinct values per window — no ties. + val x0 = floatArrayOf( + 1f, 5f, 2f, 6f, + 3f, 7f, 4f, 8f, + 9f, 13f, 10f, 14f, + 11f, 15f, 12f, 16f, + ) + assertGradMatchesFiniteDiff( + xShape = Shape(1, 1, 4, 4), + x0 = x0, + eps = 1e-2f, // larger eps — argmax must not jump under perturbation + ) { _, x -> + x.ops.maxPool2d(x, kernelSize = 2 to 2, stride = 2 to 2, padding = 0 to 0) + } + } + + @Test + fun avgPool2d_backward_distributes_uniformly() { + assertGradMatchesFiniteDiff( + xShape = Shape(1, 1, 4, 4), + x0 = FloatArray(16) { it.toFloat() * 0.1f - 0.7f }, + ) { _, x -> + x.ops.avgPool2d( + x, + kernelSize = 2 to 2, stride = 2 to 2, padding = 0 to 0, + countIncludePad = true, + ) + } + } + + @Test + fun split_backward_accumulates_chunk_grads() { + // split a length-6 vector into three length-2 chunks, multiply each + // by a distinct scalar, sum the lot. Each input element's gradient + // should equal the scalar of the chunk it belongs to (2, 3, 5). + assertGradMatchesFiniteDiff( + xShape = Shape(6), + x0 = floatArrayOf(0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f), + ) { _, x -> + val chunks = x.ops.split(x, splitSize = 2, dim = 0) + val a = x.ops.mulScalar(chunks[0], 2f) + val b = x.ops.mulScalar(chunks[1], 3f) + val c = x.ops.mulScalar(chunks[2], 5f) + x.ops.add(x.ops.add(a, b), c) + } + } + + @Test + fun upsample2d_nearest_backward_sums_block() { + assertGradMatchesFiniteDiff( + xShape = Shape(1, 1, 3, 3), + x0 = FloatArray(9) { (it - 4) * 0.25f }, + ) { _, x -> + x.ops.upsample2d(x, scale = 2 to 2, mode = UpsampleMode.Nearest, alignCorners = false) + } + } +} diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/PowLogBackwardTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/PowLogBackwardTest.kt new file mode 100644 index 00000000..ac299c7c --- /dev/null +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/PowLogBackwardTest.kt @@ -0,0 +1,133 @@ +package sk.ainet.exec.autograd + +import kotlin.math.abs +import kotlin.test.Test +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import sk.ainet.context.Phase +import sk.ainet.exec.tensor.ops.DefaultCpuOps +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.DefaultGradientTape +import sk.ainet.lang.graph.DefaultGraphExecutionContext +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.data.DenseTensorDataFactory +import sk.ainet.lang.tensor.data.FloatArrayTensorData +import sk.ainet.lang.tensor.pow +import sk.ainet.lang.tensor.log +import sk.ainet.lang.tensor.log2 +import sk.ainet.lang.tensor.log10 +import sk.ainet.lang.tensor.withRequiresGrad +import sk.ainet.lang.trace.GraphSink +import sk.ainet.lang.types.FP32 + +/** + * Tier C parity tests — every new backward formula compared to + * finite-difference (`(f(x+ε) - f(x-ε)) / 2ε`). Tolerance is + * deliberately generous (1e-2) to absorb FP32 noise; correctness, + * not precision. + */ +class PowLogBackwardTest { + + private fun ctx(): DefaultGraphExecutionContext { + val dataFactory = DenseTensorDataFactory() + val cpuOps = DefaultCpuOps(dataFactory) + val graph = DefaultComputeGraph() + return DefaultGraphExecutionContext( + baseOps = cpuOps, + phase = Phase.TRAIN, + tensorDataFactory = dataFactory, + createTapeFactory = { _ -> DefaultGradientTape(true) }, + computeGraph = graph, + baseSink = GraphSink(graph), + ) + } + + private fun floatTensor(c: DefaultGraphExecutionContext, values: FloatArray): Tensor = + c.fromFloatArray(Shape(values.size), FP32::class, values) + + private fun buf(t: Tensor<*, *>): FloatArray = (t.data as FloatArrayTensorData<*>).buffer + + /** + * Verify analytic gradient (from the tape) against the central + * finite-difference numerical gradient of [f] at each element of + * [x0]. Each element-wise partial is checked separately by perturbing + * that one element. Output is reduced to a scalar via sum-of-elements + * inside [f] so the resulting Jacobian-vector product matches a + * column of the Jacobian. + */ + private fun assertGradMatchesFiniteDiff( + x0: FloatArray, + eps: Float = 1e-3f, + tol: Float = 1e-2f, + f: (DefaultGraphExecutionContext, Tensor) -> Tensor, + ) { + // 1. Compute analytic grad via the tape. + val ctx = ctx() + val x = floatTensor(ctx, x0.copyOf()).withRequiresGrad() + val pair = ctx.record { + val out = f(this, x) + // Sum-reduce to a scalar so the gradient corresponds to df/dx + // (kept inside the record block so the sum itself is taped). + out.ops.sum(out) + } + val sumOutput = pair.second + val tape = pair.first as DefaultGradientTape + tape.computeGradients(targets = listOf(sumOutput), sources = listOf(x)) + val analyticGrad = x.grad + assertNotNull(analyticGrad, "tape should populate x.grad") + val analytic = buf(analyticGrad) + + // 2. Finite difference per-element. + for (i in x0.indices) { + val xPlus = x0.copyOf().also { it[i] += eps } + val xMinus = x0.copyOf().also { it[i] -= eps } + val ctxPlus = ctx() + val ctxMinus = ctx() + val fPlusOut = buf(f(ctxPlus, floatTensor(ctxPlus, xPlus))) + val fMinusOut = buf(f(ctxMinus, floatTensor(ctxMinus, xMinus))) + val fdGrad = (fPlusOut.sum() - fMinusOut.sum()) / (2 * eps) + val diff = abs(analytic[i] - fdGrad) + assertTrue( + diff <= tol, + "[$i] analytic=${analytic[i]} fd=$fdGrad diff=$diff tol=$tol (x0=${x0.toList()})", + ) + } + } + + private fun FloatArray.sum(): Float { + var s = 0f + for (v in this) s += v + return s + } + + @Test + fun powScalar_squared_backward_matches_finite_diff() { + assertGradMatchesFiniteDiff(floatArrayOf(0.5f, 1f, 1.5f, 2f, 3f)) { _, x -> x.pow(2) } + } + + @Test + fun powScalar_cubed_backward_matches_finite_diff() { + assertGradMatchesFiniteDiff(floatArrayOf(0.5f, 1f, 1.5f, 2f, 3f)) { _, x -> x.pow(3) } + } + + @Test + fun powScalar_real_exponent_1p5_backward_matches_finite_diff() { + assertGradMatchesFiniteDiff(floatArrayOf(0.5f, 1f, 2f, 4f)) { _, x -> x.pow(1.5f) } + } + + @Test + fun log_backward_matches_finite_diff() { + assertGradMatchesFiniteDiff(floatArrayOf(0.5f, 1f, 2f, 3f, 10f)) { _, x -> x.log() } + } + + @Test + fun log2_backward_matches_finite_diff() { + assertGradMatchesFiniteDiff(floatArrayOf(0.5f, 1f, 2f, 4f, 8f)) { _, x -> x.log2() } + } + + @Test + fun log10_backward_matches_finite_diff() { + assertGradMatchesFiniteDiff(floatArrayOf(1f, 10f, 100f)) { _, x -> x.log10() } + } +} diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/BasicMathConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/BasicMathConverter.kt index 436b398f..6c66b37b 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/BasicMathConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/BasicMathConverter.kt @@ -22,7 +22,8 @@ public class BasicMathConverter : StableHloOperationConverter { override val supportedOperations: Set = setOf( "add", "subtract", "multiply", "divide", - "sub", "mul", "div" // Common aliases + "sub", "mul", "div", // Common aliases + "pow" ) override fun convert( @@ -101,6 +102,7 @@ public class BasicMathConverter : StableHloOperationConverter { "subtract", "sub" -> "stablehlo.subtract" "multiply", "mul" -> "stablehlo.multiply" "divide", "div" -> "stablehlo.divide" + "pow" -> "stablehlo.power" else -> null } } diff --git a/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/GraphOptimizationPipeline.kt b/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/GraphOptimizationPipeline.kt index aa212d9a..d77c7222 100644 --- a/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/GraphOptimizationPipeline.kt +++ b/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/GraphOptimizationPipeline.kt @@ -6,6 +6,7 @@ import sk.ainet.compile.opt.passes.DTypeConstraintResolutionPass import sk.ainet.compile.opt.passes.DeadCodeEliminationPass import sk.ainet.compile.opt.passes.LLMFusionPass import sk.ainet.compile.opt.passes.OperationFusionPass +import sk.ainet.compile.opt.passes.PowSpecializationPass import sk.ainet.compile.opt.passes.SharedWeightDeduplicationPass import sk.ainet.compile.opt.passes.TransposeEliminationPass @@ -80,6 +81,11 @@ public class GraphOptimizationPipeline( // is the boundary where dtype problems surface — every // later pass can assume dtype-validity. DTypeConstraintResolutionPass(), + // Rewrite pow(x, 2) to multiply(x, x) before fusion so + // the downstream passes see the multiply form. Runs after + // dtype resolution (still benefits from resolved dtypes) + // and before everything else. + PowSpecializationPass(), DeadCodeEliminationPass(), ConstantFoldingPass(), OperationFusionPass() @@ -92,6 +98,7 @@ public class GraphOptimizationPipeline( public fun createAggressive(): GraphOptimizationPipeline = GraphOptimizationPipeline( passes = listOf( DTypeConstraintResolutionPass(), + PowSpecializationPass(), DeadCodeEliminationPass(), ConstantFoldingPass(), OperationFusionPass() diff --git a/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/passes/PowSpecializationPass.kt b/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/passes/PowSpecializationPass.kt new file mode 100644 index 00000000..e455bcbc --- /dev/null +++ b/skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/passes/PowSpecializationPass.kt @@ -0,0 +1,120 @@ +package sk.ainet.compile.opt.passes + +import sk.ainet.compile.opt.GraphOptimizationPass +import sk.ainet.compile.opt.GraphOptimizationResult +import sk.ainet.lang.graph.ComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.ops.MultiplyOperation +import sk.ainet.lang.tensor.ops.PowOperation + +/** + * Rewrites `powScalar(x, n)` for small integer `n` (currently `n == 2`) + * into the equivalent `multiply(x, x)` chain. The downstream multiply + * dispatch routes to the matmul / SIMD elementwise kernels — much + * cheaper than a real `pow` per element. + * + * Pattern detected: + * ``` + * PowOperation node with parameters["scalar_exponent"] == 2 and one input + * ``` + * Replaced with: + * ``` + * MultiplyOperation node with both inputs wired to the original input + * ``` + * + * Wider integer exponents (n = 3, 4, ...) intentionally not handled in + * this first cut — each adds one more layer of multiplies and the + * register-pressure / staging trade-off isn't obvious without a + * benchmark. Add them when there's a workload that wants them. + */ +public class PowSpecializationPass : GraphOptimizationPass { + + override val name: String = "pow-specialization" + + override fun apply(graph: ComputeGraph): GraphOptimizationResult { + val diagnostics = mutableListOf() + var changed = false + + // Snapshot nodes — we mutate the graph inside the loop. + val candidates = graph.nodes.filter { node -> + node.operation is PowOperation<*, *> && + node.inputs.size == 1 && + exponentInt(node) == 2 + } + + for (powNode in candidates) { + val producer = graph.edges.firstOrNull { it.destination.id == powNode.id } + ?: continue + val sourceNode = producer.source + + // Build the replacement multiply node — same id so consumer + // edges that target powNode.id continue to resolve. + val mul = GraphNode( + id = powNode.id, + operation = MultiplyOperation(), + inputs = listOf(powNode.inputs[0], powNode.inputs[0]), + outputs = powNode.outputs, + metadata = powNode.metadata, + ) + + // Snapshot edges before mutating. + val incomingToPow = graph.edges.filter { it.destination.id == powNode.id } + val outgoingFromPow = graph.edges.filter { it.source.id == powNode.id } + + graph.removeNode(powNode) + graph.addNode(mul) + + // Wire both multiply inputs to the original x. + for (i in 0..1) { + graph.addEdge( + GraphEdge( + id = "e_${sourceNode.id}_${producer.sourceOutputIndex}__${mul.id}_$i", + source = sourceNode, + destination = mul, + sourceOutputIndex = producer.sourceOutputIndex, + destinationInputIndex = i, + tensorSpec = producer.tensorSpec, + ), + ) + } + + // Restore the outgoing edges to the new node. + for (edge in outgoingFromPow) { + graph.addEdge( + GraphEdge( + id = edge.id, + source = mul, + destination = edge.destination, + sourceOutputIndex = edge.sourceOutputIndex, + destinationInputIndex = edge.destinationInputIndex, + tensorSpec = edge.tensorSpec, + ), + ) + } + + // The old incoming edge to the (removed) pow node should be + // cleaned up — removeNode usually does this, but defensively + // remove the producer edge if it survived. + for (edge in incomingToPow) { + graph.removeEdge(edge) + } + + diagnostics += "Specialized pow(${sourceNode.id}, 2) -> multiply at node ${powNode.id}" + changed = true + } + + return GraphOptimizationResult(graph, changed = changed, diagnostics = diagnostics) + } + + /** + * Returns the integer exponent stashed in [PowOperation.parameters] + * (under `"scalar_exponent"`), or `null` if absent / non-integer. + */ + private fun exponentInt(node: GraphNode): Int? { + val raw = node.operation.parameters["scalar_exponent"] ?: return null + val n = (raw as? Number)?.toDouble() ?: return null + val asInt = n.toInt() + return if (n == asInt.toDouble()) asInt else null + } +} diff --git a/skainet-compile/skainet-compile-opt/src/commonTest/kotlin/sk/ainet/compile/opt/PowSpecializationPassTest.kt b/skainet-compile/skainet-compile-opt/src/commonTest/kotlin/sk/ainet/compile/opt/PowSpecializationPassTest.kt new file mode 100644 index 00000000..d6018d06 --- /dev/null +++ b/skainet-compile/skainet-compile-opt/src/commonTest/kotlin/sk/ainet/compile/opt/PowSpecializationPassTest.kt @@ -0,0 +1,101 @@ +package sk.ainet.compile.opt + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import sk.ainet.compile.opt.passes.PowSpecializationPass +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.ops.GenericOperation +import sk.ainet.lang.tensor.ops.MultiplyOperation +import sk.ainet.lang.tensor.ops.PowOperation +import sk.ainet.lang.tensor.ops.TensorSpec + +class PowSpecializationPassTest { + + private fun spec(name: String = "t") = TensorSpec(name = name, shape = listOf(4), dtype = "Float32") + + @Test + fun rewrites_pow_x_2_to_multiply_x_x() { + val g = DefaultComputeGraph() + val input = g.addNode( + GraphNode(id = "x", operation = GenericOperation("input"), inputs = emptyList(), outputs = listOf(spec("x"))), + ) + val pow = g.addNode( + GraphNode( + id = "pow1", + operation = PowOperation( + parameters = mapOf("scalar_exponent" to 2), + ), + inputs = listOf(spec("x")), + outputs = listOf(spec("pow_out")), + ), + ) + g.addEdge(GraphEdge("e0", input, pow, tensorSpec = spec())) + + val result = PowSpecializationPass().apply(g) + assertTrue(result.changed, "pass must report changed=true") + + // The original pow node is gone; in its place there's a multiply + // with both inputs routed to x. + val mul = result.graph.nodes.firstOrNull { it.id == "pow1" } + assertTrue(mul != null && mul.operation is MultiplyOperation<*, *>, "node 'pow1' must now be a multiply") + val mulIncoming = result.graph.edges.filter { it.destination.id == "pow1" } + assertEquals(2, mulIncoming.size, "multiply must have two incoming edges") + assertTrue(mulIncoming.all { it.source.id == "x" }, "both multiply inputs route to x") + } + + @Test + fun leaves_pow_x_3_untouched_in_first_cut() { + // Tier A only specialises n=2; n=3 + higher are follow-ups. + val g = DefaultComputeGraph() + val input = g.addNode(GraphNode("x", GenericOperation("input"), emptyList(), listOf(spec("x")))) + g.addNode( + GraphNode( + "pow1", + PowOperation(parameters = mapOf("scalar_exponent" to 3)), + listOf(spec("x")), + listOf(spec("pow_out")), + ), + ) + g.addEdge(GraphEdge("e0", input, g.nodes.first { it.id == "pow1" }, tensorSpec = spec())) + + val result = PowSpecializationPass().apply(g) + assertFalse(result.changed, "pass must skip n != 2 in first cut") + val pow = result.graph.nodes.first { it.id == "pow1" } + assertTrue(pow.operation is PowOperation<*, *>, "node must remain a PowOperation") + } + + @Test + fun leaves_pow_binary_form_untouched() { + // PowOperation with two inputs (tensor exponent) is not a scalar-pow + // case — pass must ignore it. + val g = DefaultComputeGraph() + val a = g.addNode(GraphNode("a", GenericOperation("input"), emptyList(), listOf(spec("a")))) + val b = g.addNode(GraphNode("b", GenericOperation("input"), emptyList(), listOf(spec("b")))) + val pow = g.addNode( + GraphNode( + "pow1", + PowOperation(), + listOf(spec("a"), spec("b")), + listOf(spec("pow_out")), + ), + ) + g.addEdge(GraphEdge("e0", a, pow, destinationInputIndex = 0, tensorSpec = spec())) + g.addEdge(GraphEdge("e1", b, pow, destinationInputIndex = 1, tensorSpec = spec())) + + val result = PowSpecializationPass().apply(g) + assertFalse(result.changed, "pass must skip binary pow") + } + + @Test + fun leaves_graphs_without_pow_untouched() { + val g = DefaultComputeGraph() + g.addNode(GraphNode("x", GenericOperation("input"), emptyList(), listOf(spec("x")))) + g.addNode(GraphNode("relu", GenericOperation("relu"), listOf(spec("x")), listOf(spec("relu_out")))) + val result = PowSpecializationPass().apply(g) + assertFalse(result.changed) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/TensorExtensions.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/TensorExtensions.kt index 5cecc0a9..bba860c0 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/TensorExtensions.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/TensorExtensions.kt @@ -57,6 +57,16 @@ public operator fun Number.minus(t: Tensor): Tensor = public operator fun Number.times(t: Tensor): Tensor = t.ops.mulScalar(t, this) public operator fun Number.div(t: Tensor): Tensor = t.ops.rdivScalar(this, t) +// Power — element-wise. `tensor.pow(other)` for binary, `tensor.pow(n)` +// for scalar exponent. No operator form because Kotlin has no `**`. +public fun Tensor.pow(other: Tensor): Tensor = ops.pow(this, other) +public fun Tensor.pow(n: Number): Tensor = ops.powScalar(this, n) + +// Logarithms — element-wise. Backward formulas land in Tier C of #617. +public fun Tensor.log(): Tensor = ops.log(this) +public fun Tensor.log2(): Tensor = ops.log2(this) +public fun Tensor.log10(): Tensor = ops.log10(this) + // Additional convenience functions public fun Tensor.reshape(newShape: Shape): Tensor = ops.reshape(this, newShape) public fun Tensor.relu(): Tensor = ops.relu(this) diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt index 90be21c3..03c51587 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOperations.kt @@ -935,3 +935,129 @@ public class ScaledDotProductAttentionOperation( override fun clone(newParameters: Map): Operation = ScaledDotProductAttentionOperation(newParameters) } + +/** + * Element-wise power: `c[i] = a[i] ^ b[i]` (binary form) or + * `c[i] = a[i] ^ n` (scalar form — recorded with the same op class + * but a single input and the exponent stashed in + * `parameters["scalar_exponent"]` so the backward can recover it). + * + * The scalar-form distinction matters for autograd: the backward + * w.r.t. a scalar exponent is just `n * a^(n-1) * grad_out` and + * doesn't need `log`, while the binary form's exponent partial + * needs `a^b * log(a) * grad_out`. + */ +public class PowOperation( + parameters: Map = emptyMap() +) : BaseOperation("pow", "math", parameters) { + + override fun execute(inputs: List>): List> { + throw UnsupportedOperationException("Direct execution not supported in graph mode") + } + + override fun validateInputs(inputs: List): ValidationResult { + // Either binary (base + exponent tensor) or unary (base only, scalar exponent in params). + return when (inputs.size) { + 1 -> if (parameters.containsKey("scalar_exponent")) { + ValidationResult.Valid + } else { + ValidationResult.Invalid(listOf("Pow with one input requires parameters['scalar_exponent']")) + } + 2 -> ValidationResult.Valid + else -> ValidationResult.Invalid(listOf("Pow operation requires 1 (scalar) or 2 (tensor) inputs, got ${inputs.size}")) + } + } + + override fun inferOutputs(inputs: List): List { + require(inputs.isNotEmpty()) { "Pow operation requires at least 1 input" } + return listOf( + TensorSpec( + name = "pow_output", + shape = inputs[0].shape, + dtype = inputs[0].dtype, + requiresGrad = inputs.any { it.requiresGrad }, + ), + ) + } + + override fun clone(newParameters: Map): Operation = PowOperation(newParameters) +} + +/** + * Element-wise natural logarithm: `c[i] = ln(a[i])`. Mirror of + * `stablehlo.log`. Backward: `∂log(a)/∂a = grad_out / a` — formula + * lands in Tier C of #617. + */ +public class LogOperation( + parameters: Map = emptyMap(), +) : BaseOperation("log", "math", parameters) { + + override fun execute(inputs: List>): List> { + throw UnsupportedOperationException("Direct execution not supported in graph mode") + } + + override fun validateInputs(inputs: List): ValidationResult = + if (inputs.size == 1) ValidationResult.Valid + else ValidationResult.Invalid(listOf("Log operation requires exactly 1 input, got ${inputs.size}")) + + override fun inferOutputs(inputs: List): List = listOf( + TensorSpec( + name = "log_output", + shape = inputs[0].shape, + dtype = inputs[0].dtype, + requiresGrad = inputs[0].requiresGrad, + ), + ) + + override fun clone(newParameters: Map): Operation = LogOperation(newParameters) +} + +/** Element-wise base-2 logarithm: `c[i] = log2(a[i])`. */ +public class Log2Operation( + parameters: Map = emptyMap(), +) : BaseOperation("log2", "math", parameters) { + + override fun execute(inputs: List>): List> { + throw UnsupportedOperationException("Direct execution not supported in graph mode") + } + + override fun validateInputs(inputs: List): ValidationResult = + if (inputs.size == 1) ValidationResult.Valid + else ValidationResult.Invalid(listOf("Log2 operation requires exactly 1 input, got ${inputs.size}")) + + override fun inferOutputs(inputs: List): List = listOf( + TensorSpec( + name = "log2_output", + shape = inputs[0].shape, + dtype = inputs[0].dtype, + requiresGrad = inputs[0].requiresGrad, + ), + ) + + override fun clone(newParameters: Map): Operation = Log2Operation(newParameters) +} + +/** Element-wise base-10 logarithm: `c[i] = log10(a[i])`. */ +public class Log10Operation( + parameters: Map = emptyMap(), +) : BaseOperation("log10", "math", parameters) { + + override fun execute(inputs: List>): List> { + throw UnsupportedOperationException("Direct execution not supported in graph mode") + } + + override fun validateInputs(inputs: List): ValidationResult = + if (inputs.size == 1) ValidationResult.Valid + else ValidationResult.Invalid(listOf("Log10 operation requires exactly 1 input, got ${inputs.size}")) + + override fun inferOutputs(inputs: List): List = listOf( + TensorSpec( + name = "log10_output", + shape = inputs[0].shape, + dtype = inputs[0].dtype, + requiresGrad = inputs[0].requiresGrad, + ), + ) + + override fun clone(newParameters: Map): Operation = Log10Operation(newParameters) +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt index aad15c67..6d2962e1 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt @@ -193,6 +193,32 @@ public interface TensorOps { @Diff public fun sqrt(tensor: Tensor): Tensor + /** + * Element-wise power with a tensor exponent: `c[i] = a[i] ^ b[i]`. + * Shape of [b] must broadcast against [a]. Mirror of `stablehlo.power`. + */ + @Diff + public fun pow(a: Tensor, b: Tensor): Tensor + + /** + * Element-wise power with a scalar exponent: `c[i] = a[i] ^ n`. + * Backward formula only differentiates w.r.t. [a]; [n] is a constant. + */ + @Diff + public fun powScalar(a: Tensor, n: Number): Tensor + + /** Element-wise natural logarithm: `c[i] = ln(a[i])`. Mirror of `stablehlo.log`. */ + @Diff + public fun log(tensor: Tensor): Tensor + + /** Element-wise base-2 logarithm: `c[i] = log2(a[i])`. */ + @Diff + public fun log2(tensor: Tensor): Tensor + + /** Element-wise base-10 logarithm: `c[i] = log10(a[i])`. */ + @Diff + public fun log10(tensor: Tensor): Tensor + /** Element-wise absolute value: |x| */ @Diff public fun abs(tensor: Tensor): Tensor diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt index b4db5342..225af4b8 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/VoidTensorOps.kt @@ -372,6 +372,32 @@ public class VoidTensorOps : TensorOps { return VoidOpsTensor(resultData, tensor.dtype) } + override fun pow(a: Tensor, b: Tensor): Tensor { + // Power preserves shape (broadcast assumed compatible). + val resultData = dataFactory.zeros(a.shape, a.dtype) + return VoidOpsTensor(resultData, a.dtype) + } + + override fun powScalar(a: Tensor, n: Number): Tensor { + val resultData = dataFactory.zeros(a.shape, a.dtype) + return VoidOpsTensor(resultData, a.dtype) + } + + override fun log(tensor: Tensor): Tensor { + val resultData = dataFactory.zeros(tensor.shape, tensor.dtype) + return VoidOpsTensor(resultData, tensor.dtype) + } + + override fun log2(tensor: Tensor): Tensor { + val resultData = dataFactory.zeros(tensor.shape, tensor.dtype) + return VoidOpsTensor(resultData, tensor.dtype) + } + + override fun log10(tensor: Tensor): Tensor { + val resultData = dataFactory.zeros(tensor.shape, tensor.dtype) + return VoidOpsTensor(resultData, tensor.dtype) + } + override fun abs(tensor: Tensor): Tensor { val resultData = dataFactory.zeros(tensor.shape, tensor.dtype) return VoidOpsTensor(resultData, tensor.dtype)