diff --git a/Cargo.lock b/Cargo.lock
index 3cc58c4d..551bea36 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -657,9 +657,9 @@ checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12"
 
 [[package]]
 name = "log"
-version = "0.4.27"
+version = "0.4.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
+checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
 
 [[package]]
 name = "mach2"
@@ -826,9 +826,9 @@ dependencies = [
 
 [[package]]
 name = "once_cell"
-version = "1.20.3"
+version = "1.21.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
+checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
 
 [[package]]
 name = "openblas-build"
@@ -927,18 +927,18 @@ checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
 
 [[package]]
 name = "portable-atomic"
-version = "1.11.0"
+version = "1.13.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e"
+checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
 dependencies = [
  "critical-section",
 ]
 
 [[package]]
 name = "portable-atomic-util"
-version = "0.2.4"
+version = "0.2.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507"
+checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3"
 dependencies = [
  "portable-atomic",
 ]
@@ -954,9 +954,9 @@ dependencies = [
 
 [[package]]
 name = "proc-macro2"
-version = "1.0.95"
+version = "1.0.106"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
+checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
 dependencies = [
  "unicode-ident",
 ]
@@ -972,9 +972,9 @@ dependencies = [
 
 [[package]]
 name = "quote"
-version = "1.0.40"
+version = "1.0.45"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
+checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
 dependencies = [
  "proc-macro2",
 ]
@@ -1247,18 +1247,27 @@ dependencies = [
 
 [[package]]
 name = "serde"
-version = "1.0.219"
+version = "1.0.228"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
+checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
+dependencies = [
+ "serde_core",
+]
+
+[[package]]
+name = "serde_core"
+version = "1.0.228"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
 dependencies = [
  "serde_derive",
 ]
 
 [[package]]
 name = "serde_derive"
-version = "1.0.219"
+version = "1.0.228"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
+checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -1309,9 +1318,9 @@ checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
 
 [[package]]
 name = "syn"
-version = "2.0.101"
+version = "2.0.117"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf"
+checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -1673,18 +1682,18 @@ dependencies = [
 
 [[package]]
 name = "zerocopy"
-version = "0.8.25"
+version = "0.8.48"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb"
+checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9"
 dependencies = [
  "zerocopy-derive",
 ]
 
 [[package]]
 name = "zerocopy-derive"
-version = "0.8.25"
+version = "0.8.48"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef"
+checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4"
 dependencies = [
  "proc-macro2",
  "quote",
diff --git a/Cargo.toml b/Cargo.toml
index b2d72da6..a696632b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -105,6 +105,11 @@ members = [
     "ndarray-rand",
     "crates/*",
 ]
+exclude = [
+    # burn crate requires edition 2024 (Rust 1.85+) and pinned git deps.
+    # Built separately: cargo check --manifest-path crates/burn/Cargo.toml
+    "crates/burn",
+]
 default-members = [
     ".",
     "ndarray-rand",
diff --git a/crates/burn/Cargo.lock b/crates/burn/Cargo.lock
new file mode 100644
index 00000000..968eaa00
--- /dev/null
+++ b/crates/burn/Cargo.lock
@@ -0,0 +1,3320 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 4
+
+[[package]]
+name = "addr2line"
+version = "0.25.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1b5d307320b3181d6d7954e663bd7c774a838b8220fe0593c86d9fb09f498b4b"
+dependencies = [
+ "gimli",
+]
+
+[[package]]
+name = "adler2"
+version = "2.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
+
+[[package]]
+name = "aho-corasick"
+version = "1.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
+dependencies = [
+ "memchr",
+]
+
+[[package]]
+name = "allocator-api2"
+version = "0.2.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
+
+[[package]]
+name = "android_system_properties"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
+dependencies = [
+ "libc",
+]
+
+[[package]]
+name = "anyhow"
+version = "1.0.102"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
+
+[[package]]
+name = "arrayref"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb"
+
+[[package]]
+name = "arrayvec"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
+
+[[package]]
+name = "ash"
+version = "0.38.0+1.3.281"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0bb44936d800fea8f016d7f2311c6a4f97aebd5dc86f09906139ec848cf3a46f"
+dependencies = [
+ "libloading 0.8.9",
+]
+
+[[package]]
+name = "async-channel"
+version = "2.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "924ed96dd52d1b75e9c1a3e6275715fd320f5f9439fb5a4a11fa51f4221158d2"
+dependencies = [
+ "concurrent-queue",
+ "event-listener-strategy",
+ "futures-core",
+ "pin-project-lite",
+]
+
+[[package]]
+name = "atomic_float"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a"
+
+[[package]]
+name = "autocfg"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
+
+[[package]]
+name = "backtrace"
+version = "0.3.76"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bb531853791a215d7c62a30daf0dde835f381ab5de4589cfe7c649d2cbe92bd6"
+dependencies = [
+ "addr2line",
+ "cfg-if",
+ "libc",
+ "miniz_oxide",
+ "object",
+ "rustc-demangle",
+ "windows-link",
+]
+
+[[package]]
+name = "base64"
+version = "0.22.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
+
+[[package]]
+name = "base64ct"
+version = "1.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06"
+
+[[package]]
+name = "bincode"
+version = "2.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740"
+dependencies = [
+ "serde",
+ "unty",
+]
+
+[[package]]
+name = "bit-set"
+version = "0.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34ddef2995421ab6a5c779542c81ee77c115206f4ad9d5a8e05f4ff49716a3dd"
+dependencies = [
+ "bit-vec",
+]
+
+[[package]]
+name = "bit-vec"
+version = "0.9.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b71798fca2c1fe1086445a7258a4bc81e6e49dcd24c8d0dd9a1e57395b603f51"
+
+[[package]]
+name = "bitflags"
+version = "2.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af"
+dependencies = [
+ "serde_core",
+]
+
+[[package]]
+name = "blake3"
+version = "1.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2468ef7d57b3fb7e16b576e8377cdbde2320c60e1491e961d11da40fc4f02a2d"
+dependencies = [
+ "arrayref",
+ "arrayvec",
+ "cc",
+ "cfg-if",
+ "constant_time_eq",
+ "cpufeatures 0.2.17",
+]
+
+[[package]]
+name = "blas-src"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b95e83dc868db96e69795c0213143095f03de9dd3252f205d4ac716e4076a7e0"
+dependencies = [
+ "netlib-src",
+ "openblas-src",
+]
+
+[[package]]
+name = "block2"
+version = "0.6.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cdeb9d870516001442e364c5220d3574d2da8dc765554b4a617230d33fa58ef5"
+dependencies = [
+ "objc2",
+]
+
+[[package]]
+name = "bumpalo"
+version = "3.20.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb"
+
+[[package]]
+name = "burn"
+version = "0.1.0"
+dependencies = [
+ "atomic_float",
+ "blas-src",
+ "burn-backend",
+ "burn-ir",
+ "burn-std",
+ "bytemuck",
+ "bytes",
+ "const-random",
+ "itertools",
+ "libm",
+ "macerator",
+ "matrixmultiply",
+ "ndarray",
+ "num-traits",
+ "openblas-src",
+ "paste",
+ "rand",
+ "rayon",
+ "seq-macro",
+ "serde",
+]
+
+[[package]]
+name = "burn-backend"
+version = "0.21.0-pre.2"
+source = "git+https://github.com/tracel-ai/burn.git?rev=ed72d2b#ed72d2b125a364aff18aed2a53396c128e01cb42"
+dependencies = [
+ "burn-std",
+ "bytemuck",
+ "cubecl",
+ "derive-new",
+ "enumset",
+ "hashbrown 0.16.1",
+ "num-traits",
+ "portable-atomic-util",
+ "rand",
+ "rand_distr",
+ "serde",
+ "spin",
+ "thiserror",
+]
+
+[[package]]
+name = "burn-ir"
+version = "0.21.0-pre.2"
+source = "git+https://github.com/tracel-ai/burn.git?rev=ed72d2b#ed72d2b125a364aff18aed2a53396c128e01cb42"
+dependencies = [
+ "burn-backend",
+ "hashbrown 0.16.1",
+ "serde",
+]
+
+[[package]]
+name = "burn-std"
+version = "0.21.0-pre.2"
+source = "git+https://github.com/tracel-ai/burn.git?rev=ed72d2b#ed72d2b125a364aff18aed2a53396c128e01cb42"
+dependencies = [
+ "bytemuck",
+ "bytes",
+ "cubecl-common",
+ "cubecl-zspace",
+ "half",
+ "num-traits",
+ "serde",
+ "smallvec",
+]
+
+[[package]]
+name = "bytemuck"
+version = "1.25.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec"
+dependencies = [
+ "bytemuck_derive",
+]
+
+[[package]]
+name = "bytemuck_derive"
+version = "1.10.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "bytes"
+version = "1.11.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
+dependencies = [
+ "portable-atomic",
+]
+
+[[package]]
+name = "cblas-sys"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
+dependencies = [
+ "libc",
+]
+
+[[package]]
+name = "cc"
+version = "1.2.58"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e1e928d4b69e3077709075a938a05ffbedfa53a84c8f766efbf8220bb1ff60e1"
+dependencies = [
+ "find-msvc-tools",
+ "shlex",
+]
+
+[[package]]
+name = "cfg-if"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
+
+[[package]]
+name = "cfg_aliases"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
+
+[[package]]
+name = "chacha20"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601"
+dependencies = [
+ "cfg-if",
+ "cpufeatures 0.3.0",
+ "rand_core",
+]
+
+[[package]]
+name = "cmake"
+version = "0.1.58"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c0f78a02292a74a88ac736019ab962ece0bc380e3f977bf72e376c5d78ff0678"
+dependencies = [
+ "cc",
+]
+
+[[package]]
+name = "codespan-reporting"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681"
+dependencies = [
+ "serde",
+ "termcolor",
+ "unicode-width",
+]
+
+[[package]]
+name = "concurrent-queue"
+version = "2.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
+dependencies = [
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "const-random"
+version = "0.1.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359"
+dependencies = [
+ "const-random-macro",
+]
+
+[[package]]
+name = "const-random-macro"
+version = "0.1.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e"
+dependencies = [
+ "getrandom 0.2.17",
+ "once_cell",
+ "tiny-keccak",
+]
+
+[[package]]
+name = "constant_time_eq"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b"
+
+[[package]]
+name = "convert_case"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9"
+dependencies = [
+ "unicode-segmentation",
+]
+
+[[package]]
+name = "core-foundation"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6"
+dependencies = [
+ "core-foundation-sys",
+ "libc",
+]
+
+[[package]]
+name = "core-foundation-sys"
+version = "0.8.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
+
+[[package]]
+name = "cpufeatures"
+version = "0.2.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
+dependencies = [
+ "libc",
+]
+
+[[package]]
+name = "cpufeatures"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201"
+dependencies = [
+ "libc",
+]
+
+[[package]]
+name = "crc32fast"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511"
+dependencies = [
+ "cfg-if",
+]
+
+[[package]]
+name = "critical-section"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
+
+[[package]]
+name = "crossbeam-channel"
+version = "0.5.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
+dependencies = [
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
+dependencies = [
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
+dependencies = [
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
+
+[[package]]
+name = "crunchy"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5"
+
+[[package]]
+name = "cubecl"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "cubecl-core",
+ "cubecl-cuda",
+ "cubecl-ir",
+ "cubecl-runtime",
+ "cubecl-wgpu",
+ "half",
+]
+
+[[package]]
+name = "cubecl-common"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "backtrace",
+ "bincode",
+ "bytemuck",
+ "bytes",
+ "cfg-if",
+ "cfg_aliases",
+ "derive-new",
+ "derive_more",
+ "dirs",
+ "embassy-futures",
+ "embassy-time",
+ "float4",
+ "float8",
+ "futures-lite",
+ "half",
+ "hashbrown 0.16.1",
+ "log",
+ "num-traits",
+ "oneshot",
+ "parking_lot",
+ "portable-atomic",
+ "portable-atomic-util",
+ "rand",
+ "sanitize-filename",
+ "serde",
+ "serde_bytes",
+ "serde_json",
+ "spin",
+ "tynm",
+ "wasm-bindgen-futures",
+ "web-time",
+ "xxhash-rust",
+]
+
+[[package]]
+name = "cubecl-core"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "bitflags",
+ "bytemuck",
+ "cubecl-common",
+ "cubecl-ir",
+ "cubecl-macros",
+ "cubecl-runtime",
+ "cubecl-zspace",
+ "derive-new",
+ "derive_more",
+ "enumset",
+ "float-ord",
+ "half",
+ "hashbrown 0.16.1",
+ "log",
+ "num-traits",
+ "paste",
+ "serde",
+ "serde_json",
+ "variadics_please",
+]
+
+[[package]]
+name = "cubecl-cpp"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "bytemuck",
+ "cubecl-common",
+ "cubecl-core",
+ "cubecl-opt",
+ "cubecl-runtime",
+ "derive-new",
+ "half",
+ "itertools",
+ "log",
+]
+
+[[package]]
+name = "cubecl-cuda"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "bytemuck",
+ "cubecl-common",
+ "cubecl-core",
+ "cubecl-cpp",
+ "cubecl-runtime",
+ "cudarc",
+ "derive-new",
+ "half",
+ "log",
+ "serde",
+]
+
+[[package]]
+name = "cubecl-ir"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "cubecl-common",
+ "cubecl-macros-internal",
+ "derive-new",
+ "derive_more",
+ "enumset",
+ "float-ord",
+ "fnv",
+ "foldhash 0.2.0",
+ "half",
+ "hashbrown 0.16.1",
+ "num-traits",
+ "portable-atomic",
+ "serde",
+ "variadics_please",
+]
+
+[[package]]
+name = "cubecl-macros"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "cubecl-common",
+ "darling 0.23.0",
+ "derive-new",
+ "ident_case",
+ "inflections",
+ "prettyplease",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "cubecl-macros-internal"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "darling 0.23.0",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "cubecl-opt"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "cubecl-common",
+ "cubecl-core",
+ "cubecl-ir",
+ "float-ord",
+ "log",
+ "num",
+ "petgraph",
+ "smallvec",
+ "stable-vec",
+ "type-map",
+]
+
+[[package]]
+name = "cubecl-runtime"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "async-channel",
+ "bytemuck",
+ "cfg-if",
+ "cfg_aliases",
+ "cubecl-common",
+ "cubecl-ir",
+ "cubecl-zspace",
+ "derive-new",
+ "derive_more",
+ "dirs",
+ "enumset",
+ "hashbrown 0.16.1",
+ "log",
+ "md5",
+ "serde",
+ "serde_json",
+ "spin",
+ "thiserror",
+ "toml",
+ "variadics_please",
+ "wasm-bindgen-futures",
+ "web-time",
+]
+
+[[package]]
+name = "cubecl-wgpu"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "async-channel",
+ "bytemuck",
+ "cfg-if",
+ "cfg_aliases",
+ "cubecl-common",
+ "cubecl-core",
+ "cubecl-ir",
+ "cubecl-runtime",
+ "derive-new",
+ "derive_more",
+ "half",
+ "hashbrown 0.16.1",
+ "log",
+ "sanitize-filename",
+ "wgpu",
+]
+
+[[package]]
+name = "cubecl-zspace"
+version = "0.10.0-pre.2"
+source = "git+https://github.com/tracel-ai/cubecl?rev=5b831a3cfac3eca0065fe0dbf57cddf5946d1586#5b831a3cfac3eca0065fe0dbf57cddf5946d1586"
+dependencies = [
+ "derive-new",
+ "serde",
+ "smallvec",
+]
+
+[[package]]
+name = "cudarc"
+version = "0.19.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f071cd6a7b5d51607df76aa2d426aaabc7a74bc6bdb885b8afa63a880572ad9b"
+dependencies = [
+ "libloading 0.9.0",
+]
+
+[[package]]
+name = "darling"
+version = "0.20.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
+dependencies = [
+ "darling_core 0.20.11",
+ "darling_macro 0.20.11",
+]
+
+[[package]]
+name = "darling"
+version = "0.21.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0"
+dependencies = [
+ "darling_core 0.21.3",
+ "darling_macro 0.21.3",
+]
+
+[[package]]
+name = "darling"
+version = "0.23.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d"
+dependencies = [
+ "darling_core 0.23.0",
+ "darling_macro 0.23.0",
+]
+
+[[package]]
+name = "darling_core"
+version = "0.20.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e"
+dependencies = [
+ "fnv",
+ "ident_case",
+ "proc-macro2",
+ "quote",
+ "strsim",
+ "syn",
+]
+
+[[package]]
+name = "darling_core"
+version = "0.21.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4"
+dependencies = [
+ "fnv",
+ "ident_case",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "darling_core"
+version = "0.23.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0"
+dependencies = [
+ "ident_case",
+ "proc-macro2",
+ "quote",
+ "strsim",
+ "syn",
+]
+
+[[package]]
+name = "darling_macro"
+version = "0.20.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
+dependencies = [
+ "darling_core 0.20.11",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "darling_macro"
+version = "0.21.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81"
+dependencies = [
+ "darling_core 0.21.3",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "darling_macro"
+version = "0.23.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d"
+dependencies = [
+ "darling_core 0.23.0",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "der"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "71fd89660b2dc699704064e59e9dba0147b903e85319429e131620d022be411b"
+dependencies = [
+ "pem-rfc7468",
+ "zeroize",
+]
+
+[[package]]
+name = "derive-new"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "derive_more"
+version = "2.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134"
+dependencies = [
+ "derive_more-impl",
+]
+
+[[package]]
+name = "derive_more-impl"
+version = "2.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb"
+dependencies = [
+ "convert_case",
+ "proc-macro2",
+ "quote",
+ "rustc_version",
+ "syn",
+ "unicode-xid",
+]
+
+[[package]]
+name = "dirs"
+version = "6.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e"
+dependencies = [
+ "dirs-sys",
+]
+
+[[package]]
+name = "dirs-sys"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab"
+dependencies = [
+ "libc",
+ "option-ext",
+ "redox_users",
+ "windows-sys",
+]
+
+[[package]]
+name = "dispatch2"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e0e367e4e7da84520dedcac1901e4da967309406d1e51017ae1abfb97adbd38"
+dependencies = [
+ "bitflags",
+ "objc2",
+]
+
+[[package]]
+name = "dlib"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ab8ecd87370524b461f8557c119c405552c396ed91fc0a8eec68679eab26f94a"
+dependencies = [
+ "libloading 0.8.9",
+]
+
+[[package]]
+name = "document-features"
+version = "0.2.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61"
+dependencies = [
+ "litrs",
+]
+
+[[package]]
+name = "either"
+version = "1.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
+
+[[package]]
+name = "embassy-futures"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc2d050bdc5c21e0862a89256ed8029ae6c290a93aecefc73084b3002cdebb01"
+
+[[package]]
+name = "embassy-time"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "592b0c143ec626e821d4d90da51a2bd91d559d6c442b7c74a47d368c9e23d97a"
+dependencies = [
+ "cfg-if",
+ "critical-section",
+ "document-features",
+ "embassy-time-driver",
+ "embedded-hal 0.2.7",
+ "embedded-hal 1.0.0",
+ "embedded-hal-async",
+ "futures-core",
+]
+
+[[package]]
+name = "embassy-time-driver"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ee71af1b3a0deaa53eaf2d39252f83504c853646e472400b763060389b9fcc9"
+dependencies = [
+ "document-features",
+]
+
+[[package]]
+name = "embedded-hal"
+version = "0.2.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "35949884794ad573cf46071e41c9b60efb0cb311e3ca01f7af807af1debc66ff"
+dependencies = [
+ "nb 0.1.3",
+ "void",
+]
+
+[[package]]
+name = "embedded-hal"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "361a90feb7004eca4019fb28352a9465666b24f840f5c3cddf0ff13920590b89"
+
+[[package]]
+name = "embedded-hal-async"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0c4c685bbef7fe13c3c6dd4da26841ed3980ef33e841cddfa15ce8a8fb3f1884"
+dependencies = [
+ "embedded-hal 1.0.0",
+]
+
+[[package]]
+name = "enumset"
+version = "1.1.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25b07a8dfbbbfc0064c0a6bdf9edcf966de6b1c33ce344bdeca3b41615452634"
+dependencies = [
+ "enumset_derive",
+ "serde",
+]
+
+[[package]]
+name = "enumset_derive"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f43e744e4ea338060faee68ed933e46e722fb7f3617e722a5772d7e856d8b3ce"
+dependencies = [
+ "darling 0.21.3",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "equivalent"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
+
+[[package]]
+name = "errno"
+version = "0.3.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
+dependencies = [
+ "libc",
+ "windows-sys",
+]
+
+[[package]]
+name = "event-listener"
+version = "5.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab"
+dependencies = [
+ "concurrent-queue",
+ "parking",
+ "pin-project-lite",
+]
+
+[[package]]
+name = "event-listener-strategy"
+version = "0.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93"
+dependencies = [
+ "event-listener",
+ "pin-project-lite",
+]
+
+[[package]]
+name = "fastrand"
+version = "2.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
+
+[[package]]
+name = "filetime"
+version = "0.2.27"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "libredox",
+]
+
+[[package]]
+name = "find-msvc-tools"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582"
+
+[[package]]
+name = "fixedbitset"
+version = "0.5.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99"
+
+[[package]]
+name = "flate2"
+version = "1.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c"
+dependencies = [
+ "crc32fast",
+ "miniz_oxide",
+]
+
+[[package]]
+name = "float-ord"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
+
+[[package]]
+name = "float4"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a5404bf31d22893d61cf24d4dda149d8e6b2ff07601c3cb3be651031f61a4ed"
+
+[[package]]
+name = "float8"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c2d1f04709a8ac06e8e8042875a3c466cc4832d3c1a18dbcb9dba3c6e83046bc"
+dependencies = [
+ "half",
+]
+
+[[package]]
+name = "fnv"
+version = "1.0.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
+
+[[package]]
+name = "foldhash"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
+
+[[package]]
+name = "foldhash"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
+
+[[package]]
+name = "foreign-types"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
+dependencies = [
+ "foreign-types-shared",
+]
+
+[[package]]
+name = "foreign-types-shared"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
+
+[[package]]
+name = "futures-core"
+version = "0.3.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
+
+[[package]]
+name = "futures-io"
+version = "0.3.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
+
+[[package]]
+name = "futures-lite"
+version = "2.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
+dependencies = [
+ "fastrand",
+ "futures-core",
+ "futures-io",
+ "parking",
+ "pin-project-lite",
+]
+
+[[package]]
+name = "futures-task"
+version = "0.3.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
+
+[[package]]
+name = "futures-util"
+version = "0.3.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "pin-project-lite",
+ "slab",
+]
+
+[[package]]
+name = "getrandom"
+version = "0.2.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "wasi",
+]
+
+[[package]]
+name = "getrandom"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "r-efi",
+ "rand_core",
+ "wasip2",
+ "wasip3",
+]
+
+[[package]]
+name = "gimli"
+version = "0.32.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7"
+
+[[package]]
+name = "gl_generator"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1a95dfc23a2b4a9a2f5ab41d194f8bfda3cabec42af4e39f08c339eb2a0c124d"
+dependencies = [
+ "khronos_api",
+ "log",
+ "xml-rs",
+]
+
+[[package]]
+name = "glow"
+version = "0.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "29038e1c483364cc6bb3cf78feee1816002e127c331a1eec55a4d202b9e1adb5"
+dependencies = [
+ "js-sys",
+ "slotmap",
+ "wasm-bindgen",
+ "web-sys",
+]
+
+[[package]]
+name = "glutin_wgl_sys"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2c4ee00b289aba7a9e5306d57c2d05499b2e5dc427f84ac708bd2c090212cf3e"
+dependencies = [
+ "gl_generator",
+]
+
+[[package]]
+name = "gpu-allocator"
+version = "0.28.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "51255ea7cfaadb6c5f1528d43e92a82acb2b96c43365989a28b2d44ee38f8795"
+dependencies = [
+ "ash",
+ "hashbrown 0.16.1",
+ "log",
+ "presser",
+ "thiserror",
+ "windows",
+]
+
+[[package]]
+name = "gpu-descriptor"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b89c83349105e3732062a895becfc71a8f921bb71ecbbdd8ff99263e3b53a0ca"
+dependencies = [
+ "bitflags",
+ "gpu-descriptor-types",
+ "hashbrown 0.15.5",
+]
+
+[[package]]
+name = "gpu-descriptor-types"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fdf242682df893b86f33a73828fb09ca4b2d3bb6cc95249707fc684d27484b91"
+dependencies = [
+ "bitflags",
+]
+
+[[package]]
+name = "half"
+version = "2.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b"
+dependencies = [
+ "bytemuck",
+ "cfg-if",
+ "crunchy",
+ "num-traits",
+ "serde",
+ "zerocopy",
+]
+
+[[package]]
+name = "hashbrown"
+version = "0.15.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
+dependencies = [
+ "foldhash 0.1.5",
+]
+
+[[package]]
+name = "hashbrown"
+version = "0.16.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
+dependencies = [
+ "allocator-api2",
+ "equivalent",
+ "foldhash 0.2.0",
+ "serde",
+ "serde_core",
+]
+
+[[package]]
+name = "heck"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
+
+[[package]]
+name = "hermit-abi"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
+
+[[package]]
+name = "hexf-parse"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df"
+
+[[package]]
+name = "http"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a"
+dependencies = [
+ "bytes",
+ "itoa",
+]
+
+[[package]]
+name = "httparse"
+version = "1.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
+
+[[package]]
+name = "id-arena"
+version = "2.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954"
+
+[[package]]
+name = "ident_case"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
+
+[[package]]
+name = "indexmap"
+version = "2.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017"
+dependencies = [
+ "equivalent",
+ "hashbrown 0.16.1",
+ "serde",
+ "serde_core",
+]
+
+[[package]]
+name = "inflections"
+version = "1.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a257582fdcde896fd96463bf2d40eefea0580021c0712a0e2b028b60b47a837a"
+
+[[package]]
+name = "itertools"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
+dependencies = [
+ "either",
+]
+
+[[package]]
+name = "itoa"
+version = "1.0.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682"
+
+[[package]]
+name = "jni-sys"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258"
+dependencies = [
+ "jni-sys 0.4.1",
+]
+
+[[package]]
+name = "jni-sys"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2"
+dependencies = [
+ "jni-sys-macros",
+]
+
+[[package]]
+name = "jni-sys-macros"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264"
+dependencies = [
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "js-sys"
+version = "0.3.92"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cc4c90f45aa2e6eacbe8645f77fdea542ac97a494bcd117a67df9ff4d611f995"
+dependencies = [
+ "cfg-if",
+ "futures-util",
+ "once_cell",
+ "wasm-bindgen",
+]
+
+[[package]]
+name = "khronos-egl"
+version = "6.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6aae1df220ece3c0ada96b8153459b67eebe9ae9212258bb0134ae60416fdf76"
+dependencies = [
+ "libc",
+ "libloading 0.8.9",
+ "pkg-config",
+]
+
+[[package]]
+name = "khronos_api"
+version = "3.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e2db585e1d738fc771bf08a151420d3ed193d9d895a36df7f6f8a9456b911ddc"
+
+[[package]]
+name = "leb128fmt"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2"
+
+[[package]]
+name = "libc"
+version = "0.2.183"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d"
+
+[[package]]
+name = "libloading"
+version = "0.8.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55"
+dependencies = [
+ "cfg-if",
+ "windows-link",
+]
+
+[[package]]
+name = "libloading"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60"
+dependencies = [
+ "cfg-if",
+ "windows-link",
+]
+
+[[package]]
+name = "libm"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981"
+
+[[package]]
+name = "libredox"
+version = "0.1.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7ddbf48fd451246b1f8c2610bd3b4ac0cc6e149d89832867093ab69a17194f08"
+dependencies = [
+ "bitflags",
+ "libc",
+ "plain",
+ "redox_syscall 0.7.3",
+]
+
+[[package]]
+name = "linux-raw-sys"
+version = "0.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53"
+
+[[package]]
+name = "litrs"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092"
+
+[[package]]
+name = "lock_api"
+version = "0.4.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965"
+dependencies = [
+ "scopeguard",
+]
+
+[[package]]
+name = "log"
+version = "0.4.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
+
+[[package]]
+name = "macerator"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09e6046277c48f8a44bd6cfae65a1a261cab6622fb6d4a003f5597e4e4f4a661"
+dependencies = [
+ "bytemuck",
+ "cfg_aliases",
+ "half",
+ "macerator-macros",
+ "moddef",
+ "num-traits",
+ "paste",
+ "rustc_version",
+]
+
+[[package]]
+name = "macerator-macros"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "23ee1819976b67f4d782390c55a75c13401c7a988517f7f8e60a33484dc2e00a"
+dependencies = [
+ "darling 0.20.11",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "matrixmultiply"
+version = "0.3.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
+dependencies = [
+ "autocfg",
+ "num_cpus",
+ "once_cell",
+ "rawpointer",
+ "thread-tree",
+]
+
+[[package]]
+name = "md5"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0"
+
+[[package]]
+name = "memchr"
+version = "2.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
+
+[[package]]
+name = "miniz_oxide"
+version = "0.8.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316"
+dependencies = [
+ "adler2",
+ "simd-adler32",
+]
+
+[[package]]
+name = "moddef"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4a0b3262dc837d2513fe2ef31ff8461352ef932dcca31ba0c0abe33547cf6b9b"
+
+[[package]]
+name = "naga"
+version = "29.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aa2630921705b9b01dcdd0b6864b9562ca3c1951eecd0f0c4f5f04f61e412647"
+dependencies = [
+ "arrayvec",
+ "bit-set",
+ "bitflags",
+ "cfg-if",
+ "cfg_aliases",
+ "codespan-reporting",
+ "half",
+ "hashbrown 0.16.1",
+ "hexf-parse",
+ "indexmap",
+ "libm",
+ "log",
+ "num-traits",
+ "once_cell",
+ "rustc-hash 1.1.0",
+ "spirv",
+ "thiserror",
+ "unicode-ident",
+]
+
+[[package]]
+name = "native-tls"
+version = "0.2.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2"
+dependencies = [
+ "libc",
+ "log",
+ "openssl",
+ "openssl-probe",
+ "openssl-sys",
+ "schannel",
+ "security-framework",
+ "security-framework-sys",
+ "tempfile",
+]
+
+[[package]]
+name = "nb"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "801d31da0513b6ec5214e9bf433a77966320625a37860f910be265be6e18d06f"
+dependencies = [
+ "nb 1.1.0",
+]
+
+[[package]]
+name = "nb"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8d5439c4ad607c3c23abf66de8c8bf57ba8adcd1f129e699851a6e43935d339d"
+
+[[package]]
+name = "ndarray"
+version = "0.17.2"
+dependencies = [
+ "blake3",
+ "cblas-sys",
+ "libc",
+ "matrixmultiply",
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "portable-atomic",
+ "portable-atomic-util",
+ "rawpointer",
+ "rayon",
+]
+
+[[package]]
+name = "ndk-sys"
+version = "0.6.0+11769913"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ee6cda3051665f1fb8d9e08fc35c96d5a244fb1be711a03b71118828afc9a873"
+dependencies = [
+ "jni-sys 0.3.1",
+]
+
+[[package]]
+name = "netlib-src"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "39f41f36bb4d46906d5a72da5b73a804d9de1a7282eb7c89617201acda7b8212"
+dependencies = [
+ "cmake",
+]
+
+[[package]]
+name = "nom"
+version = "8.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "df9761775871bdef83bee530e60050f7e54b1105350d6884eb0fb4f46c2f9405"
+dependencies = [
+ "memchr",
+]
+
+[[package]]
+name = "num"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23"
+dependencies = [
+ "num-bigint",
+ "num-complex",
+ "num-integer",
+ "num-iter",
+ "num-rational",
+ "num-traits",
+]
+
+[[package]]
+name = "num-bigint"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
+dependencies = [
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
+name = "num-complex"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "num-iter"
+version = "0.1.45"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf"
+dependencies = [
+ "autocfg",
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
+name = "num-rational"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
+dependencies = [
+ "num-bigint",
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
+name = "num-traits"
+version = "0.2.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
+dependencies = [
+ "autocfg",
+ "libm",
+]
+
+[[package]]
+name = "num_cpus"
+version = "1.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
+dependencies = [
+ "hermit-abi",
+ "libc",
+]
+
+[[package]]
+name = "objc2"
+version = "0.6.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3a12a8ed07aefc768292f076dc3ac8c48f3781c8f2d5851dd3d98950e8c5a89f"
+dependencies = [
+ "objc2-encode",
+]
+
+[[package]]
+name = "objc2-core-foundation"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536"
+dependencies = [
+ "bitflags",
+ "dispatch2",
+ "objc2",
+]
+
+[[package]]
+name = "objc2-encode"
+version = "4.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33"
+
+[[package]]
+name = "objc2-foundation"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272"
+dependencies = [
+ "bitflags",
+ "objc2",
+ "objc2-core-foundation",
+]
+
+[[package]]
+name = "objc2-metal"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a0125f776a10d00af4152d74616409f0d4a2053a6f57fa5b7d6aa2854ac04794"
+dependencies = [
+ "bitflags",
+ "block2",
+ "objc2",
+ "objc2-foundation",
+]
+
+[[package]]
+name = "objc2-quartz-core"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "96c1358452b371bf9f104e21ec536d37a650eb10f7ee379fff67d2e08d537f1f"
+dependencies = [
+ "bitflags",
+ "objc2",
+ "objc2-core-foundation",
+ "objc2-foundation",
+ "objc2-metal",
+]
+
+[[package]]
+name = "object"
+version = "0.37.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe"
+dependencies = [
+ "memchr",
+]
+
+[[package]]
+name = "once_cell"
+version = "1.21.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
+
+[[package]]
+name = "oneshot"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cfe21416a02c693fb9f980befcb230ecc70b0b3d1cc4abf88b9675c4c1457f0c"
+
+[[package]]
+name = "openblas-build"
+version = "0.10.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cd235aa8876fa5c4be452efde09b9b8bafa19aea0bf14a4926508213082439a3"
+dependencies = [
+ "anyhow",
+ "cc",
+ "flate2",
+ "tar",
+ "thiserror",
+ "ureq",
+]
+
+[[package]]
+name = "openblas-src"
+version = "0.10.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fccd2c4f5271ab871f2069cb6f1a13ef2c0db50e1145ce03428ee541f4c63c4f"
+dependencies = [
+ "dirs",
+ "openblas-build",
+ "pkg-config",
+ "vcpkg",
+]
+
+[[package]]
+name = "openssl"
+version = "0.10.76"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf"
+dependencies = [
+ "bitflags",
+ "cfg-if",
+ "foreign-types",
+ "libc",
+ "once_cell",
+ "openssl-macros",
+ "openssl-sys",
+]
+
+[[package]]
+name = "openssl-macros"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "openssl-probe"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
+
+[[package]]
+name = "openssl-sys"
+version = "0.9.112"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb"
+dependencies = [
+ "cc",
+ "libc",
+ "pkg-config",
+ "vcpkg",
+]
+
+[[package]]
+name = "option-ext"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
+
+[[package]]
+name = "ordered-float"
+version = "5.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b7d950ca161dc355eaf28f82b11345ed76c6e1f6eb1f4f4479e0323b9e2fbd0e"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "parking"
+version = "2.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
+
+[[package]]
+name = "parking_lot"
+version = "0.12.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a"
+dependencies = [
+ "lock_api",
+ "parking_lot_core",
+]
+
+[[package]]
+name = "parking_lot_core"
+version = "0.9.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "redox_syscall 0.5.18",
+ "smallvec",
+ "windows-link",
+]
+
+[[package]]
+name = "paste"
+version = "1.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
+
+[[package]]
+name = "pem-rfc7468"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a6305423e0e7738146434843d1694d621cce767262b2a86910beab705e4493d9"
+dependencies = [
+ "base64ct",
+]
+
+[[package]]
+name = "percent-encoding"
+version = "2.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220"
+
+[[package]]
+name = "petgraph"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455"
+dependencies = [
+ "fixedbitset",
+ "hashbrown 0.15.5",
+ "indexmap",
+ "serde",
+]
+
+[[package]]
+name = "pin-project-lite"
+version = "0.2.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd"
+
+[[package]]
+name = "pkg-config"
+version = "0.3.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
+
+[[package]]
+name = "plain"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6"
+
+[[package]]
+name = "portable-atomic"
+version = "1.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
+dependencies = [
+ "serde",
+]
+
+[[package]]
+name = "portable-atomic-util"
+version = "0.2.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3"
+dependencies = [
+ "portable-atomic",
+]
+
+[[package]]
+name = "presser"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa"
+
+[[package]]
+name = "prettyplease"
+version = "0.2.37"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
+dependencies = [
+ "proc-macro2",
+ "syn",
+]
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.106"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "profiling"
+version = "1.0.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3eb8486b569e12e2c32ad3e204dbaba5e4b5b216e9367044f25f1dba42341773"
+
+[[package]]
+name = "quote"
+version = "1.0.45"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "r-efi"
+version = "6.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf"
+
+[[package]]
+name = "rand"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8"
+dependencies = [
+ "chacha20",
+ "getrandom 0.4.2",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba"
+
+[[package]]
+name = "rand_distr"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4d431c2703ccf129de4d45253c03f49ebb22b97d6ad79ee3ecfc7e3f4862c1d8"
+dependencies = [
+ "num-traits",
+ "rand",
+]
+
+[[package]]
+name = "range-alloc"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ca45419789ae5a7899559e9512e58ca889e41f04f1f2445e9f4b290ceccd1d08"
+
+[[package]]
+name = "raw-window-handle"
+version = "0.6.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539"
+
+[[package]]
+name = "raw-window-metal"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "40d213455a5f1dc59214213c7330e074ddf8114c9a42411eb890c767357ce135"
+dependencies = [
+ "objc2",
+ "objc2-core-foundation",
+ "objc2-foundation",
+ "objc2-quartz-core",
+]
+
+[[package]]
+name = "rawpointer"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
+
+[[package]]
+name = "rayon"
+version = "1.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
+dependencies = [
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
+dependencies = [
+ "crossbeam-deque",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "redox_syscall"
+version = "0.5.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d"
+dependencies = [
+ "bitflags",
+]
+
+[[package]]
+name = "redox_syscall"
+version = "0.7.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ce70a74e890531977d37e532c34d45e9055d2409ed08ddba14529471ed0be16"
+dependencies = [
+ "bitflags",
+]
+
+[[package]]
+name = "redox_users"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac"
+dependencies = [
+ "getrandom 0.2.17",
+ "libredox",
+ "thiserror",
+]
+
+[[package]]
+name = "regex"
+version = "1.12.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-automata",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-automata"
+version = "0.4.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-syntax"
+version = "0.8.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a"
+
+[[package]]
+name = "renderdoc-sys"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832"
+
+[[package]]
+name = "rustc-demangle"
+version = "0.1.27"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b50b8869d9fc858ce7266cce0194bd74df58b9d0e3f6df3a9fc8eb470d95c09d"
+
+[[package]]
+name = "rustc-hash"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
+
+[[package]]
+name = "rustc-hash"
+version = "2.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe"
+
+[[package]]
+name = "rustc_version"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
+dependencies = [
+ "semver",
+]
+
+[[package]]
+name = "rustix"
+version = "1.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190"
+dependencies = [
"bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "sanitize-filename" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc984f4f9ceb736a7bb755c3e3bd17dc56370af2600c9780dcc48c66453da34d" +dependencies = [ + "regex", +] + +[[package]] +name = "schannel" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "seq-macro" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_bytes" +version = "0.11.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d440709e79d88e51ac01c4b72fc6cb7314017bb7da9eeff678aa94c10e3ea8" +dependencies = [ + "serde", + "serde_core", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_spanned" +version = "1.1.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "876ac351060d4f882bb1032b6369eb0aef79ad9df1ea8bc404874d8cc3d0cd98" +dependencies = [ + "serde_core", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "simd-adler32" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "slotmap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdd58c3c93c3d278ca835519292445cb4b0d4dc59ccfdf7ceadaab3f8aeb4038" +dependencies = [ + "version_check", +] + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +dependencies = [ + "serde", +] + +[[package]] +name = "spin" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" +dependencies = [ + "lock_api", + "portable-atomic", +] + +[[package]] +name = "spirv" +version = "0.4.0+sdk-1.4.341.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9571ea910ebd84c86af4b3ed27f9dbdc6ad06f17c5f96146b2b671e2976744f" +dependencies = [ + "bitflags", +] + +[[package]] +name = "stable-vec" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dac7bc0f7d0d44329b200020effbc25a534d89fa142af95e3ddf76113412a5e" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tar" +version = "0.4.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22692a6476a21fa75fdfc11d452fda482af402c008cdbaf3476414e122040973" +dependencies = [ + "filetime", + "libc", + "xattr", +] + +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + "windows-sys", +] + +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thread-tree" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffbd370cb847953a25954d9f63e14824a36113f8c72eecf6eccef5dc4b45d630" +dependencies = [ + "crossbeam-channel", +] + +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + +[[package]] +name = "toml" +version = "1.1.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8195ca05e4eb728f4ba94f3e3291661320af739c4e43779cbdfae82ab239fcc" +dependencies = [ + "indexmap", + "serde_core", + "serde_spanned", + "toml_datetime", + "toml_parser", + "toml_writer", + "winnow", +] + +[[package]] +name = "toml_datetime" +version = "1.1.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97251a7c317e03ad83774a8752a7e81fb6067740609f75ea2b585b569a59198f" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_parser" +version = "1.1.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2334f11ee363607eb04df9b8fc8a13ca1715a72ba8662a26ac285c98aabb4011" +dependencies = [ + "winnow", +] + +[[package]] +name = "toml_writer" +version = "1.1.0+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d282ade6016312faf3e41e57ebbba0c073e4056dab1232ab1cb624199648f8ed" + +[[package]] +name = "tynm" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21cdb0fc8f85c98b1ec812bc4cd69faf6c0fa2fc17d44ea3c2cdd38dc08e999" +dependencies = [ + "nom", +] + +[[package]] +name = "type-map" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb30dbbd9036155e74adad6812e9898d03ec374946234fbcebd5dfc7b9187b90" +dependencies = [ + "rustc-hash 2.1.2", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-segmentation" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" + +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "unty" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" + +[[package]] +name = "ureq" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0" +dependencies = [ + "base64", + "der", + "log", + "native-tls", + "percent-encoding", + "rustls-pki-types", + "ureq-proto", + "utf8-zero", + "webpki-root-certs", +] + +[[package]] +name = "ureq-proto" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c" +dependencies = [ + "base64", + "http", + "httparse", + "log", +] + +[[package]] +name = "utf8-zero" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e" + +[[package]] +name = "variadics_please" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41b6d82be61465f97d42bd1d15bf20f3b0a3a0905018f38f9d6f6962055b0b5c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.115" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6523d69017b7633e396a89c5efab138161ed5aafcbc8d3e5c5a42ae38f50495a" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d1faf851e778dfa54db7cd438b70758eba9755cb47403f3496edd7c8fc212f0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.115" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e3a6c758eb2f701ed3d052ff5737f5bfe6614326ea7f3bbac7156192dc32e67" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.115" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "921de2737904886b52bcbb237301552d05969a6f9c40d261eb0533c8b055fedf" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.115" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a93e946af942b58934c604527337bad9ae33ba1d5c6900bbb41c2c07c2364a93" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "wayland-sys" +version = "0.31.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "374f6b70e8e0d6bf9461a32988fd553b59ff630964924dad6e4a4eb6bd538d17" +dependencies = [ + "dlib", + "log", + "once_cell", + "pkg-config", +] + +[[package]] +name = "web-sys" +version = "0.3.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84cde8507f4d7cfcb1185b8cb5890c494ffea65edbe1ba82cfd63661c805ed94" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-root-certs" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "wgpu" +version = "29.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72c239a9a747bbd379590985bac952c2e53cb19873f7072b3370c6a6a8e06837" +dependencies = [ + "arrayvec", + "bitflags", + "bytemuck", + "cfg-if", + "cfg_aliases", + "document-features", + "hashbrown 0.16.1", + "js-sys", + "log", + "naga", + "parking_lot", + "portable-atomic", + "profiling", + "raw-window-handle", + "smallvec", + "static_assertions", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "wgpu-core", + "wgpu-hal", + "wgpu-types", +] + +[[package]] +name = "wgpu-core" +version = "29.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e80ac6cf1895df6342f87d975162108f9d98772a0d74bc404ab7304ac29469e" +dependencies = [ + "arrayvec", + "bit-set", + "bit-vec", + "bitflags", + "bytemuck", + "cfg_aliases", + "document-features", + "hashbrown 0.16.1", + "indexmap", + "log", + "naga", + "once_cell", + "parking_lot", + "portable-atomic", + "profiling", + "raw-window-handle", + "rustc-hash 1.1.0", + "smallvec", + "thiserror", + "wgpu-core-deps-apple", + "wgpu-core-deps-emscripten", + "wgpu-core-deps-windows-linux-android", + "wgpu-hal", + "wgpu-naga-bridge", + "wgpu-types", +] + +[[package]] +name = "wgpu-core-deps-apple" +version = "29.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43acd053312501689cd92a01a9638d37f3e41a5fd9534875efa8917ee2d11ac0" +dependencies = [ + "wgpu-hal", +] + +[[package]] +name = 
"wgpu-core-deps-emscripten" +version = "29.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef043bf135cc68b6f667c55ff4e345ce2b5924d75bad36a47921b0287ca4b24a" +dependencies = [ + "wgpu-hal", +] + +[[package]] +name = "wgpu-core-deps-windows-linux-android" +version = "29.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "725d5c006a8c02967b6d93ef04f6537ec4593313e330cfe86d9d3f946eb90f28" +dependencies = [ + "wgpu-hal", +] + +[[package]] +name = "wgpu-hal" +version = "29.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89a47aef47636562f3937285af4c44b4b5b404b46577471411cc5313a921da7e" +dependencies = [ + "android_system_properties", + "arrayvec", + "ash", + "bit-set", + "bitflags", + "block2", + "bytemuck", + "cfg-if", + "cfg_aliases", + "glow", + "glutin_wgl_sys", + "gpu-allocator", + "gpu-descriptor", + "hashbrown 0.16.1", + "js-sys", + "khronos-egl", + "libc", + "libloading 0.8.9", + "log", + "naga", + "ndk-sys", + "objc2", + "objc2-core-foundation", + "objc2-foundation", + "objc2-metal", + "objc2-quartz-core", + "once_cell", + "ordered-float", + "parking_lot", + "portable-atomic", + "portable-atomic-util", + "profiling", + "range-alloc", + "raw-window-handle", + "raw-window-metal", + "renderdoc-sys", + "smallvec", + "thiserror", + "wasm-bindgen", + "wayland-sys", + "web-sys", + "wgpu-naga-bridge", + "wgpu-types", + "windows", + "windows-core", +] + +[[package]] +name = "wgpu-naga-bridge" +version = "29.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4684f4410da0cf95a4cb63bb5edaac022461dedb6adf0b64d0d9b5f6890d51" +dependencies = [ + "naga", + "wgpu-types", +] + +[[package]] +name = "wgpu-types" +version = "29.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec2675540fb1a5cfa5ef122d3d5f390e2c75711a0b946410f2d6ac3a0f77d1f6" +dependencies = [ + "bitflags", + "bytemuck", + "js-sys", + "log", + "raw-window-handle", + "web-sys", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "windows" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580" +dependencies = [ + "windows-collections", + "windows-core", + "windows-future", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610" +dependencies = [ + "windows-core", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-future" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb" +dependencies = [ + "windows-core", + "windows-link", + "windows-threading", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-numerics" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26" +dependencies = [ + "windows-core", + "windows-link", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-threading" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37" +dependencies = [ + "windows-link", +] + +[[package]] +name = "winnow" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a90e88e4667264a994d34e6d1ab2d26d398dcdca8b7f52bec8668957517fc7d8" + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + 
"bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "xattr" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" +dependencies = [ + "libc", + "rustix", +] + +[[package]] +name = "xml-rs" +version = "0.8.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ae8337f8a065cfc972643663ea4279e04e7256de865aa66fe25cec5fb912d3f" + +[[package]] +name = "xxhash-rust" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" + +[[package]] +name = "zerocopy" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml new file mode 100644 index 00000000..6d26b361 --- /dev/null +++ b/crates/burn/Cargo.toml @@ -0,0 +1,76 @@ +[package] +name = "burn" +version = "0.1.0" +edition = "2024" +license = "MIT OR Apache-2.0" +publish = false +description = """ +Burn ndarray backend forked into adaworldapi/ndarray for SIMD augmentation. +Source: upstream burn-ndarray (tracel-ai/burn, v0.21.0-pre.2). +Goal: replace macerator SIMD with crate::simd F32x16 + LazyLock dispatch, +add bgz-tensor AttentionTable compiled attention path. +""" + +[features] +default = ["std", "simd", "multi-threads"] +multi-threads = ["rayon", "ndarray/rayon", "matrixmultiply/threading"] +simd = ["macerator", "bytemuck", "seq-macro", "itertools"] +std = [ + "burn-std/std", + "burn-backend/std", + "burn-ir/std", + "ndarray/std", + "matrixmultiply/std", + "rand/std", + "rand/std_rng", + "num-traits/std", + "macerator/std", +] +blas-openblas = ["blas-src/openblas", "ndarray/blas", "openblas-src"] +blas-openblas-system = ["blas-src/openblas", "ndarray/blas", "openblas-src/system"] +blas-netlib = ["blas-src/netlib", "ndarray/blas"] +export_tests = [] + +[dependencies] +# Upstream burn crates (from git main — matches source code we copied) +# Upstream burn crates — vendored at pinned commit, we only override our additions. 
+# Our changes: crates/burn/src/ops/tensor.rs (try_vml_unary + 4 SIMD wires)
+#              crates/burn/src/ops/activation.rs (fused sigmoid)
+burn-backend = { git = "https://github.com/tracel-ai/burn.git", rev = "ed72d2b", default-features = false }
+burn-std = { git = "https://github.com/tracel-ai/burn.git", rev = "ed72d2b", default-features = false }
+burn-ir = { git = "https://github.com/tracel-ai/burn.git", rev = "ed72d2b", default-features = false }
+
+# ndarray — uses our workspace root (adaworldapi/ndarray with SIMD + HPC)
+ndarray = { path = "../..", default-features = false }
+
+# Matrix multiply
+matrixmultiply = { version = "0.3", default-features = false }
+
+# Element traits
+num-traits = { version = "0.2", default-features = false }
+libm = "0.2"
+atomic_float = "1"
+const-random = "0.1"
+paste = "1"
+
+# Random
+rand = { version = "0.10", default-features = false, features = ["std_rng"] }
+
+# Serialization
+serde = { version = "1", features = ["derive"] }
+
+# SIMD (macerator — upstream burn's choice, will augment with crate::simd)
+macerator = { version = "0.3", default-features = false, optional = true }
+bytemuck = { version = "1", optional = true }
+seq-macro = { version = "0.3", optional = true }
+itertools = { version = "0.14", optional = true }
+
+# Parallel
+rayon = { version = "1", optional = true }
+
+# BLAS (optional)
+blas-src = { version = "0.10", default-features = false, optional = true }
+openblas-src = { version = "0.10", optional = true }
+
+[dev-dependencies]
+bytes = "1"
diff --git a/crates/burn/src/backend.rs b/crates/burn/src/backend.rs
new file mode 100644
index 00000000..6a27a9fd
--- /dev/null
+++ b/crates/burn/src/backend.rs
@@ -0,0 +1,222 @@
+use crate::rand::NdArrayRng;
+use crate::{NdArrayQTensor, NdArrayTensor};
+use crate::{
+    SharedArray,
+    element::{FloatNdArrayElement, IntNdArrayElement, QuantElement},
+};
+use alloc::string::String;
+use burn_backend::quantization::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue};
+use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
+use burn_backend::{Backend, DType, DeviceId, DeviceOps};
+use burn_ir::{BackendIr, HandleKind, TensorHandle};
+use burn_std::BoolStore;
+use burn_std::stub::Mutex;
+use core::marker::PhantomData;
+use rand::SeedableRng;
+
+pub(crate) static SEED: Mutex<Option<NdArrayRng>> = Mutex::new(None);
+
+/// The device type for the ndarray backend.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
+pub enum NdArrayDevice {
+    /// The CPU device.
+    #[default]
+    Cpu,
+}
+
+impl DeviceOps for NdArrayDevice {}
+
+impl burn_backend::Device for NdArrayDevice {
+    fn from_id(_device_id: DeviceId) -> Self {
+        Self::Cpu
+    }
+
+    fn to_id(&self) -> DeviceId {
+        DeviceId {
+            type_id: 0,
+            index_id: 0,
+        }
+    }
+}
+
+/// Tensor backend that uses the [ndarray](ndarray) crate for executing tensor operations.
+///
+/// This backend is compatible with CPUs and can be compiled for almost any platform, including
+/// `wasm`, `arm`, and `x86`.
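+///
+/// A minimal usage sketch (illustrative only; it relies on the `E = f32`,
+/// `I = i64`, `Q = i8` type-parameter defaults and is not part of the
+/// upstream sources):
+///
+/// ```ignore
+/// use burn::{NdArray, NdArrayDevice};
+///
+/// type B = NdArray<f32>; // f32 floats, default int/quant elements
+/// let device = NdArrayDevice::Cpu;
+/// ```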
+#[derive(Clone, Copy, Default, Debug)]
+pub struct NdArray<E = f32, I = i64, Q = i8>
+where
+    NdArrayTensor: From<SharedArray<E>>,
+    NdArrayTensor: From<SharedArray<I>>,
+{
+    _e: PhantomData<E>,
+    _i: PhantomData<I>,
+    _q: PhantomData<Q>,
+}
+
+impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> Backend for NdArray<E, I, Q>
+where
+    NdArrayTensor: From<SharedArray<E>>,
+    NdArrayTensor: From<SharedArray<I>>,
+{
+    type Device = NdArrayDevice;
+
+    type FloatTensorPrimitive = NdArrayTensor;
+    type FloatElem = E;
+
+    type IntTensorPrimitive = NdArrayTensor;
+    type IntElem = I;
+
+    type BoolTensorPrimitive = NdArrayTensor;
+    type BoolElem = bool;
+
+    type QuantizedTensorPrimitive = NdArrayQTensor;
+
+    fn ad_enabled(_device: &Self::Device) -> bool {
+        false
+    }
+
+    fn name(_device: &Self::Device) -> String {
+        String::from("ndarray")
+    }
+
+    fn seed(_device: &Self::Device, seed: u64) {
+        let rng = NdArrayRng::seed_from_u64(seed);
+        let mut seed = SEED.lock().unwrap();
+        *seed = Some(rng);
+    }
+
+    fn dtype_usage(_device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
+        match dtype {
+            DType::F64
+            | DType::F32
+            | DType::Flex32
+            | DType::I64
+            | DType::I32
+            | DType::I16
+            | DType::I8
+            | DType::U64
+            | DType::U32
+            | DType::U16
+            | DType::U8
+            | DType::Bool(BoolStore::Native) => burn_backend::DTypeUsage::general(),
+            DType::F16 | DType::BF16 | DType::Bool(_) => burn_backend::DTypeUsageSet::empty(),
+            DType::QFloat(scheme) => {
+                match scheme {
+                    QuantScheme {
+                        level: QuantLevel::Tensor | QuantLevel::Block(_),
+                        mode: QuantMode::Symmetric,
+                        #[cfg(not(feature = "export_tests"))]
+                        value: QuantValue::Q8F | QuantValue::Q8S,
+                        // For tests, "native" sub-byte quant serves as a reference for value equality.
+                        // Values are stored as i8 regardless.
+                        #[cfg(feature = "export_tests")]
+                        value:
+                            QuantValue::Q8F
+                                | QuantValue::Q8S
+                                | QuantValue::Q4F
+                                | QuantValue::Q4S
+                                | QuantValue::Q2F
+                                | QuantValue::Q2S,
+                        store: QuantStore::Native,
+                        ..
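+                        // Net effect: only symmetric 8-bit quantization in the
+                        // native i8 store is reported as generally usable (plus
+                        // the sub-byte values under `export_tests`); any other
+                        // scheme falls through to the empty usage set below.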
+                    } => burn_backend::DTypeUsage::general(),
+                    _scheme => burn_backend::DTypeUsageSet::empty(),
+                }
+            }
+        }
+    }
+
+    fn device_count(_: u16) -> usize {
+        1
+    }
+}
+
+impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> BackendIr for NdArray<E, I, Q>
+where
+    NdArrayTensor: From<SharedArray<E>>,
+    NdArrayTensor: From<SharedArray<I>>,
+{
+    type Handle = HandleKind<Self>;
+
+    fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
+        match handle.handle {
+            HandleKind::Float(handle) => handle,
+            _ => panic!("Expected float handle, got {}", handle.handle.name()),
+        }
+    }
+
+    fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
+        match handle.handle {
+            HandleKind::Int(handle) => handle,
+            _ => panic!("Expected int handle, got {}", handle.handle.name()),
+        }
+    }
+
+    fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
+        match handle.handle {
+            HandleKind::Bool(handle) => handle,
+            _ => panic!("Expected bool handle, got {}", handle.handle.name()),
+        }
+    }
+
+    fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
+        match handle.handle {
+            HandleKind::Quantized(handle) => handle,
+            _ => panic!("Expected quantized handle, got {}", handle.handle.name()),
+        }
+    }
+
+    fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
+        HandleKind::Float(tensor)
+    }
+
+    fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
+        HandleKind::Int(tensor)
+    }
+
+    fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
+        HandleKind::Bool(tensor)
+    }
+
+    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
+        HandleKind::Quantized(tensor)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use burn_backend::QTensorPrimitive;
+
+    #[test]
+    fn should_support_dtypes() {
+        type B = NdArray;
+        let device = Default::default();
+
+        assert!(B::supports_dtype(&device, DType::F64));
+        assert!(B::supports_dtype(&device, DType::F32));
+        assert!(B::supports_dtype(&device, DType::Flex32));
+        assert!(B::supports_dtype(&device, DType::I64));
+        assert!(B::supports_dtype(&device, DType::I32));
+        assert!(B::supports_dtype(&device, DType::I16));
+        assert!(B::supports_dtype(&device, DType::I8));
+        assert!(B::supports_dtype(&device, DType::U64));
+        assert!(B::supports_dtype(&device, DType::U32));
+        assert!(B::supports_dtype(&device, DType::U16));
+        assert!(B::supports_dtype(&device, DType::U8));
+        assert!(B::supports_dtype(&device, DType::Bool(BoolStore::Native)));
+        assert!(B::supports_dtype(
+            &device,
+            DType::QFloat(NdArrayQTensor::default_scheme())
+        ));
+
+        assert!(!B::supports_dtype(&device, DType::F16));
+        assert!(!B::supports_dtype(&device, DType::BF16));
+        // QuantStore::U32 not supported
+        assert!(!B::supports_dtype(
+            &device,
+            DType::QFloat(QuantScheme::default())
+        ));
+    }
+}
diff --git a/crates/burn/src/element.rs b/crates/burn/src/element.rs
new file mode 100644
index 00000000..8485352e
--- /dev/null
+++ b/crates/burn/src/element.rs
@@ -0,0 +1,207 @@
+use burn_backend::Element;
+use num_traits::Signed;
+
+#[cfg(not(feature = "std"))]
+#[allow(unused_imports)]
+use num_traits::Float;
+
+use num_traits::Pow;
+
+use libm::{log1p, log1pf};
+
+/// A float element for ndarray backend.
+pub trait FloatNdArrayElement: NdArrayElement + Signed + core::cmp::PartialOrd
+where
+    Self: Sized,
+{
+}
+
+/// An int element for ndarray backend.
+pub trait IntNdArrayElement: NdArrayElement + core::cmp::PartialOrd {}
+
+/// A general element for ndarray backend.
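+///
+/// Everything the generic op implementations need from a dtype is bundled
+/// here, so kernels can be written once over a single `E: NdArrayElement`
+/// bound: the burn `Element` contract, ndarray's scalar and linalg traits,
+/// the exp/log helpers below, and the arithmetic operators.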
+pub trait NdArrayElement:
+    Element
+    + ndarray::LinalgScalar
+    + ndarray::ScalarOperand
+    + ExpElement
+    + AddAssignElement
+    + num_traits::FromPrimitive
+    + core::ops::AddAssign
+    + core::cmp::PartialEq
+    + core::ops::Rem
+{
+}
+
+/// An element for ndarray backend that supports exp ops.
+pub trait ExpElement {
+    /// Exponent
+    fn exp_elem(self) -> Self;
+    /// Log
+    fn log_elem(self) -> Self;
+    /// Log1p
+    fn log1p_elem(self) -> Self;
+    /// Powf
+    fn powf_elem(self, value: f32) -> Self;
+    /// Powi
+    fn powi_elem(self, value: i32) -> Self;
+    /// Sqrt
+    fn sqrt_elem(self) -> Self;
+    /// Abs
+    fn abs_elem(self) -> Self;
+}
+
+/// The addition assignment operator implemented for ndarray elements.
+pub trait AddAssignElement<Rhs = Self> {
+    /// Performs the addition assignment operation.
+    ///
+    /// For `bool`, this corresponds to logical OR assignment.
+    fn add_assign(&mut self, rhs: Rhs);
+}
+
+impl<E: core::ops::AddAssign> AddAssignElement for E {
+    fn add_assign(&mut self, rhs: Self) {
+        *self += rhs;
+    }
+}
+
+impl AddAssignElement for bool {
+    fn add_assign(&mut self, rhs: Self) {
+        *self = *self || rhs; // logical OR for bool
+    }
+}
+
+/// A quantized element for the ndarray backend.
+pub trait QuantElement: NdArrayElement {}
+
+impl QuantElement for i8 {}
+
+impl FloatNdArrayElement for f64 {}
+impl FloatNdArrayElement for f32 {}
+
+impl IntNdArrayElement for i64 {}
+impl IntNdArrayElement for i32 {}
+impl IntNdArrayElement for i16 {}
+impl IntNdArrayElement for i8 {}
+
+impl IntNdArrayElement for u64 {}
+impl IntNdArrayElement for u32 {}
+impl IntNdArrayElement for u16 {}
+impl IntNdArrayElement for u8 {}
+
+macro_rules! make_float {
+    (
+        $ty:ty,
+        $log1p:expr
+    ) => {
+        impl NdArrayElement for $ty {}
+
+        #[allow(clippy::cast_abs_to_unsigned)]
+        impl ExpElement for $ty {
+            #[inline(always)]
+            fn exp_elem(self) -> Self {
+                self.exp()
+            }
+
+            #[inline(always)]
+            fn log_elem(self) -> Self {
+                self.ln()
+            }
+
+            #[inline(always)]
+            fn log1p_elem(self) -> Self {
+                $log1p(self)
+            }
+
+            #[inline(always)]
+            fn powf_elem(self, value: f32) -> Self {
+                self.pow(value)
+            }
+
+            #[inline(always)]
+            fn powi_elem(self, value: i32) -> Self {
+                #[cfg(feature = "std")]
+                let val = self.powi(value);
+
+                #[cfg(not(feature = "std"))]
+                let val = Self::powf_elem(self, value as f32);
+
+                val
+            }
+
+            #[inline(always)]
+            fn sqrt_elem(self) -> Self {
+                self.sqrt()
+            }
+
+            #[inline(always)]
+            fn abs_elem(self) -> Self {
+                self.abs()
+            }
+        }
+    };
+}
+macro_rules! make_int {
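+    // Integer elements have no native exp/ln/sqrt, so each op below routes
+    // through f32 and truncates back to `$ty`; `$abs` lets the unsigned types
+    // substitute the identity function for `wrapping_abs`.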
make_int { + ( + $ty:ty, + $abs:expr + ) => { + impl NdArrayElement for $ty {} + + #[allow(clippy::cast_abs_to_unsigned)] + impl ExpElement for $ty { + #[inline(always)] + fn exp_elem(self) -> Self { + (self as f32).exp() as $ty + } + + #[inline(always)] + fn log_elem(self) -> Self { + (self as f32).ln() as $ty + } + + #[inline(always)] + fn log1p_elem(self) -> Self { + log1pf(self as f32) as $ty + } + + #[inline(always)] + fn powf_elem(self, value: f32) -> Self { + (self as f32).pow(value) as $ty + } + + #[inline(always)] + fn powi_elem(self, value: i32) -> Self { + #[cfg(feature = "std")] + let val = f32::powi(self as f32, value) as $ty; + + #[cfg(not(feature = "std"))] + let val = Self::powf_elem(self, value as f32); + + val + } + + #[inline(always)] + fn sqrt_elem(self) -> Self { + (self as f32).sqrt() as $ty + } + + #[inline(always)] + fn abs_elem(self) -> Self { + $abs(self) + } + } + }; +} + +make_float!(f64, log1p); +make_float!(f32, log1pf); + +make_int!(i64, i64::wrapping_abs); +make_int!(i32, i32::wrapping_abs); +make_int!(i16, i16::wrapping_abs); +make_int!(i8, i8::wrapping_abs); +make_int!(u64, |x| x); +make_int!(u32, |x| x); +make_int!(u16, |x| x); +make_int!(u8, |x| x); diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs new file mode 100644 index 00000000..34a46255 --- /dev/null +++ b/crates/burn/src/lib.rs @@ -0,0 +1,29 @@ +#![cfg_attr(not(feature = "std"), no_std)] +#![warn(missing_docs)] +#![cfg_attr(docsrs, feature(doc_cfg))] + +//! Burn ndarray backend. + +#[cfg(any( + feature = "blas-netlib", + feature = "blas-openblas", + feature = "blas-openblas-system", +))] +extern crate blas_src; + +mod backend; +mod element; +mod ops; +mod parallel; +mod rand; +mod sharing; +mod storage; +mod tensor; + +pub use backend::*; +pub use element::*; +pub(crate) use sharing::*; +pub(crate) use storage::*; +pub use tensor::*; + +extern crate alloc; diff --git a/crates/burn/src/ops/activation.rs b/crates/burn/src/ops/activation.rs new file mode 100644 index 00000000..dea8533d --- /dev/null +++ b/crates/burn/src/ops/activation.rs @@ -0,0 +1,45 @@ +use crate::{ + NdArray, NdArrayStorage, NdArrayTensor, SharedArray, + element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}, + execute_with_numeric_dtype, + ops::NdArrayMathOps, +}; +use burn_backend::{ElementConversion, TensorMetadata, ops::ActivationOps, tensor::FloatTensor}; + +impl ActivationOps + for NdArray +where + NdArrayTensor: From>, + NdArrayTensor: From>, +{ + fn relu(tensor: FloatTensor) -> FloatTensor { + execute_with_numeric_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, 0.elem())) + } + + /// Sigmoid via ndarray::hpc::activations::sigmoid_f32 (fused F32x16 SIMD). + /// + /// Default impl decomposes into 6 separate ops: neg, exp, add, log, neg, exp. + /// Our version does `1 / (1 + exp(-x))` in one SIMD pass with F32x16. + fn sigmoid(tensor: FloatTensor) -> FloatTensor { + #[cfg(feature = "simd")] + if let NdArrayTensor::F32(ref storage) = tensor { + let view = storage.view(); + if view.is_standard_layout() { + if let Some(input) = view.as_slice() { + let mut output = alloc::vec![0.0f32; input.len()]; + ndarray::hpc::activations::sigmoid_f32(input, &mut output); + let shape: alloc::vec::Vec = view.shape().to_vec(); + let array = ndarray::Array::from_shape_vec(ndarray::IxDyn(&shape), output) + .expect("sigmoid output shape mismatch"); + return NdArrayTensor::F32(NdArrayStorage::Owned(array.into_shared())); + } + } + } + // Fallback: decomposed sigmoid via Backend ops (non-f32 or non-contiguous). 
+        use burn_backend::ops::FloatTensorOps;
+        let tensor_neg = Self::float_neg(tensor);
+        let tensor_exp = Self::float_exp(tensor_neg);
+        let tensor_add = Self::float_add_scalar(tensor_exp, 1.0.into());
+        Self::float_recip(tensor_add)
+    }
+}
diff --git a/crates/burn/src/ops/adaptive_avgpool.rs b/crates/burn/src/ops/adaptive_avgpool.rs
new file mode 100644
index 00000000..baaee09f
--- /dev/null
+++ b/crates/burn/src/ops/adaptive_avgpool.rs
@@ -0,0 +1,103 @@
+use crate::{
+    SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
+};
+use burn_backend::ElementConversion;
+use ndarray::Array4;
+
+#[cfg(not(feature = "std"))]
+#[allow(unused_imports)]
+use num_traits::Float;
+
+pub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(
+    x: SharedArray<E>,
+    output_size: [usize; 2],
+) -> SharedArray<E> {
+    let [batch_size, channels, input_height, input_width] = x.shape().try_into().unwrap();
+
+    let mut output = Array4::from_elem(
+        (batch_size, channels, output_size[0], output_size[1]),
+        0.elem(),
+    );
+    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
+
+    run_par!(|| {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
+            let b = k / channels;
+            let c = k % channels;
+
+            let output = unsafe_shared_out.get();
+            for h in 0..output_size[0] {
+                for w in 0..output_size[1] {
+                    let ih_start = start_index(h, output_size[0], input_height);
+                    let ih_end = end_index(h, output_size[0], input_height);
+                    let iw_start = start_index(w, output_size[1], input_width);
+                    let iw_end = end_index(w, output_size[1], input_width);
+
+                    let mut sum_val: E = 0.elem();
+
+                    for ih in ih_start..ih_end {
+                        for iw in iw_start..iw_end {
+                            sum_val += x[[b, c, ih, iw]];
+                        }
+                    }
+
+                    let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();
+                    output[[b, c, h, w]] = sum_val / count.elem();
+                }
+            }
+        })
+    });
+
+    output.into_dyn().into_shared()
+}
+
+pub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(
+    x: SharedArray<E>,
+    grad: SharedArray<E>,
+) -> SharedArray<E> {
+    let [_, _, input_height, input_width] = x.shape().try_into().unwrap();
+    let [batch_size, channels, output_height, output_width] = grad.shape().try_into().unwrap();
+
+    let mut output_grad =
+        Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem());
+    let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);
+
+    run_par!(|| {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
+            let b = k / channels;
+            let c = k % channels;
+
+            let output_grad = unsafe_shared_out.get();
+            for oh in 0..output_height {
+                for ow in 0..output_width {
+                    let ih_start = start_index(oh, output_height, input_height);
+                    let ih_end = end_index(oh, output_height, input_height);
+
+                    let iw_start = start_index(ow, output_width, input_width);
+                    let iw_end = end_index(ow, output_width, input_width);
+
+                    let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();
+
+                    for ih in ih_start..ih_end {
+                        for iw in iw_start..iw_end {
+                            output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] / count.elem();
+                        }
+                    }
+                }
+            }
+        })
+    });
+
+    output_grad.into_dyn().into_shared()
+}
+
+fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
+    ((output_size_index as f32 * input_size as f32) / output_size as f32).floor() as usize
+}
+
+fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
+    let index =
+        (((output_size_index + 1) as f32 * input_size as f32) / output_size as f32).ceil() as usize;
+
+    usize::min(index, input_size)
+}
diff --git
a/crates/burn/src/ops/avgpool.rs b/crates/burn/src/ops/avgpool.rs new file mode 100644 index 00000000..4d015dd9 --- /dev/null +++ b/crates/burn/src/ops/avgpool.rs @@ -0,0 +1,172 @@ +use crate::{ + SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef, +}; + +use burn_backend::ElementConversion; +use burn_backend::ops::conv::calculate_pool_output_size; +use ndarray::Array4; + +pub(crate) fn avg_pool2d( + x: SharedArray, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ceil_mode: bool, +) -> SharedArray { + let [kernel_height, kernel_width] = kernel_size; + let [padding_height, padding_width] = padding; + let [stride_height, stride_width] = stride; + let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); + + let out_height = calculate_pool_output_size( + kernel_height, + stride_height, + padding_height, + 1, + x_height, + ceil_mode, + ); + let out_width = calculate_pool_output_size( + kernel_width, + stride_width, + padding_width, + 1, + x_width, + ceil_mode, + ); + + // Padded input bounds (for count_include_pad calculation) + let padded_height = x_height + 2 * padding_height; + let padded_width = x_width + 2 * padding_width; + + let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), 0.elem()); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output = unsafe_shared_out.get(); + + for oh in 0..out_height { + for ow in 0..out_width { + let mut sum_val: E = 0.elem(); + let mut valid_count = 0usize; + let mut padded_count = 0usize; + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw; + + // Check if within padded bounds (excludes ceil_mode extensions) + if ih < padded_height && iw < padded_width { + padded_count += 1; + + // Check if within valid (non-padding) input bounds + if ih >= padding_height + && ih < x_height + padding_height + && iw >= padding_width + && iw < x_width + padding_width + { + let ih_valid = ih - padding_height; + let iw_valid = iw - padding_width; + sum_val += x[[b, c, ih_valid, iw_valid]]; + valid_count += 1; + } + } + } + } + + // count_include_pad: count positions within padded bounds (not ceil_mode extensions) + // !count_include_pad: count only valid (non-padding) positions + let count: E = if count_include_pad { + (padded_count as i32).elem() + } else { + (valid_count as i32).elem() + }; + + output[[b, c, oh, ow]] = sum_val / count; + } + } + }) + }); + + output.into_dyn().into_shared() +} + +pub(crate) fn avg_pool2d_backward( + x: SharedArray, + grad: SharedArray, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + _ceil_mode: bool, +) -> SharedArray { + let [kernel_height, kernel_width] = kernel_size; + let [stride_height, stride_width] = stride; + let [padding_height, padding_width] = padding; + let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); + let [_batch_size, _channels, out_height, out_width] = grad.shape().try_into().unwrap(); + + // Padded input bounds (for count_include_pad calculation) + let padded_height = x_height + 2 * padding_height; + let padded_width = x_width + 2 * padding_width; + + let mut output_grad = Array4::from_elem((batch_size, channels, x_height, x_width), 0.elem()); + let unsafe_shared_grad = 
UnsafeSharedRef::new(&mut output_grad); + + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output_grad = unsafe_shared_grad.get(); + + for oh in 0..out_height { + for ow in 0..out_width { + let ih_start_kernel = oh * stride_height; + let iw_start_kernel = ow * stride_width; + + let ih_end_kernel = ih_start_kernel + kernel_height; + let iw_end_kernel = iw_start_kernel + kernel_width; + + // Clip to valid input bounds (for gradient distribution) + let ih_start = usize::max(ih_start_kernel, padding_height); + let iw_start = usize::max(iw_start_kernel, padding_width); + let ih_end = usize::min(ih_end_kernel, x_height + padding_height); + let iw_end = usize::min(iw_end_kernel, x_width + padding_width); + + // Calculate count based on count_include_pad + let count = if count_include_pad { + // Count positions within padded bounds (not ceil_mode extensions) + let ih_start_padded = ih_start_kernel; + let iw_start_padded = iw_start_kernel; + let ih_end_padded = usize::min(ih_end_kernel, padded_height); + let iw_end_padded = usize::min(iw_end_kernel, padded_width); + (ih_end_padded - ih_start_padded) * (iw_end_padded - iw_start_padded) + } else { + // Count only valid (non-padding) positions + (ih_end - ih_start) * (iw_end - iw_start) + }; + + for ih in ih_start..ih_end { + for iw in iw_start..iw_end { + let ih = ih - padding_height; + let iw = iw - padding_width; + + output_grad[[b, c, ih, iw]] += + grad[[b, c, oh, ow]] / (count as i32).elem(); + } + } + } + } + }) + }); + + output_grad.into_dyn().into_shared() +} diff --git a/crates/burn/src/ops/base.rs b/crates/burn/src/ops/base.rs new file mode 100644 index 00000000..5d2ce429 --- /dev/null +++ b/crates/burn/src/ops/base.rs @@ -0,0 +1,1448 @@ +use alloc::{vec, vec::Vec}; +use burn_backend::element::{Element, ElementConversion}; +#[cfg(feature = "simd")] +use burn_backend::{DType, quantization::QuantValue}; +use core::fmt::Debug; +use core::marker::PhantomData; +use ndarray::IntoDimension; +use ndarray::SliceInfo; +use ndarray::Zip; +use ndarray::s; +use ndarray::{Array2, ArrayD}; +use num_traits::Signed; +#[cfg(feature = "simd")] +use paste::paste; + +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] +use num_traits::Float; + +#[cfg(feature = "simd")] +use crate::ops::simd::{ + binary::try_binary_simd, + binary_elemwise::{ + VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecClamp, VecDiv, VecMax, VecMin, VecMul, VecSub, + try_binary_scalar_simd, + }, + cmp::{ + VecEquals, VecGreater, VecGreaterEq, VecLower, VecLowerEq, try_cmp_scalar_simd, + try_cmp_simd, + }, + unary::{RecipVec, VecAbs, VecBitNot, try_unary_simd}, +}; +use crate::reshape; +use crate::{ + IntNdArrayElement, ShapeOps, + ops::macros::{ + cummax_dim, cummin_dim, cumprod_dim, cumsum_dim, keepdim, mean_dim, prod_dim, sum_dim, + }, +}; +use crate::{SharedArray, element::NdArrayElement}; +use burn_backend::ops::unfold::calculate_unfold_shape; +use burn_backend::{Shape, Slice}; +use ndarray::ArrayView; +use ndarray::Axis; +use ndarray::Dim; +use ndarray::IxDyn; +use ndarray::SliceInfoElem; + +pub struct NdArrayOps { + e: PhantomData, +} + +pub(crate) struct NdArrayMathOps { + e: PhantomData, +} + +impl NdArrayOps +where + E: Copy + Debug + Element + crate::AddAssignElement, +{ + pub fn slice(tensor: ArrayView, slices: &[Slice]) -> SharedArray { + let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims()); + tensor.slice_move(slices.as_slice()).to_shared() + } + + pub fn 
slice_assign( + tensor: SharedArray, + slices: &[Slice], + value: SharedArray, + ) -> SharedArray { + let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims()); + let mut array = tensor.into_owned(); + array.slice_mut(slices.as_slice()).assign(&value); + array.into_shared() + } + + pub fn mask_where( + tensor: SharedArray, + mask: SharedArray, + source: SharedArray, + ) -> SharedArray { + let tensor = tensor.broadcast(mask.dim()).unwrap(); + let source = source.broadcast(mask.dim()).unwrap(); + Zip::from(&tensor) + .and(&mask) + .and(&source) + .map_collect(|&x, &mask_val, &y| if mask_val { y } else { x }) + .into_shared() + } + + pub fn mask_fill(tensor: SharedArray, mask: SharedArray, value: E) -> SharedArray { + // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique + let mut output = tensor.into_owned(); + let broadcast_mask = mask.broadcast(output.dim()).unwrap(); + Zip::from(&mut output) + .and(&broadcast_mask) + .for_each(|out, &mask_val| { + if mask_val { + *out = value; + } + }); + output.into_shared() + } + + pub fn gather( + dim: usize, + mut tensor: SharedArray, + mut indices: SharedArray, + ) -> SharedArray { + let ndims = tensor.shape().num_dims(); + if dim != ndims - 1 { + tensor.swap_axes(ndims - 1, dim); + indices.swap_axes(ndims - 1, dim); + } + let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape().into_shape()); + let (size_tensor, size_index) = (shape_tensor[ndims - 1], shape_indices[ndims - 1]); + let batch_size = Self::gather_batch_size(shape_tensor, &shape_indices); + + let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])); + let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])); + let mut output = Array2::from_elem((batch_size, size_index), 0.elem::()); + + for b in 0..batch_size { + let indices = indices.slice(s!(b, ..)); + for (i, index) in indices.iter().enumerate() { + output[[b, i]] = tensor[[b, index.elem::() as usize]]; + } + } + + let mut output = NdArrayOps::reshape(output.into_shared().into_dyn(), shape_indices); + + if dim != ndims - 1 { + output.swap_axes(ndims - 1, dim); + } + + output + } + + pub fn scatter( + dim: usize, + mut tensor: SharedArray, + mut indices: SharedArray, + mut value: SharedArray, + ) -> SharedArray { + let ndims = tensor.shape().num_dims(); + if dim != ndims - 1 { + tensor.swap_axes(ndims - 1, dim); + indices.swap_axes(ndims - 1, dim); + value.swap_axes(ndims - 1, dim); + } + + let (shape_tensor, shape_indices, shape_value) = + (tensor.shape().into_shape(), indices.shape(), value.shape()); + let (size_tensor, size_index, size_value) = ( + shape_tensor[ndims - 1], + shape_indices[ndims - 1], + shape_value[ndims - 1], + ); + let batch_size = Self::gather_batch_size(&shape_tensor, shape_indices); + + if shape_value != shape_indices { + panic!( + "Invalid dimension: the shape of the index tensor should be the same as the value \ + tensor: Index {:?} value {:?}", + shape_indices, shape_value + ); + } + + let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])); + let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])); + let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])); + + for b in 0..batch_size { + let indices = indices.slice(s!(b, ..)); + + for (i, index) in indices.iter().enumerate() { + let index = index.elem::() as usize; + tensor[[b, index]].add_assign(value[[b, i]]); + } + } + + let mut output = 
NdArrayOps::reshape(tensor.into_shared().into_dyn(), shape_tensor); + if dim != ndims - 1 { + output.swap_axes(ndims - 1, dim); + } + output + } + + fn gather_batch_size(shape_tensor: &[usize], shape_indices: &[usize]) -> usize { + let ndims = shape_tensor.num_dims(); + let mut batch_size = 1; + + for i in 0..ndims - 1 { + if shape_tensor[i] != shape_indices[i] { + panic!( + "Unsupported dimension, only the last dimension can differ: Tensor {:?} Index \ + {:?}", + shape_tensor, shape_indices + ); + } + batch_size *= shape_indices[i]; + } + + batch_size + } + + pub fn reshape(tensor: SharedArray, shape: Shape) -> SharedArray { + reshape!( + ty E, + shape shape, + array tensor, + d shape.num_dims() + ) + } + + pub(crate) fn concatenate( + arrays: &[ndarray::ArrayView], + dim: usize, + ) -> SharedArray { + let array = ndarray::concatenate(Axis(dim), arrays) + .unwrap() + .into_shared(); + + // Transform column-major layout into row-major (standard) layout. (fix #1053) + // Get shape first (via reference), then pass ownership to avoid clone + let shape = array.shape().into_shape(); + Self::reshape(array, shape) + } + + pub fn cat(tensors: Vec>, dim: usize) -> SharedArray { + let arrays: Vec<_> = tensors.iter().map(|t| t.view()).collect(); + Self::concatenate(&arrays, dim) + } + + #[allow(clippy::wrong_self_convention)] + fn to_slice_args_with_steps( + burn_slices: &[burn_backend::Slice], + ndims: usize, + ) -> Vec { + let mut slices = vec![SliceInfoElem::NewAxis; ndims]; + + for i in 0..ndims { + slices[i] = if i < burn_slices.len() { + let slice = &burn_slices[i]; + + // Check for empty range (would result in no elements) + if let Some(end) = slice.end + && slice.start == end + { + SliceInfoElem::Slice { + start: 0, + end: Some(0), + step: 1, + } + } else { + // Pass slice parameters directly to ndarray + // ndarray handles both positive and negative steps correctly: + // - Positive step: iterates forward from start + // - Negative step: iterates backward from the last element in range + SliceInfoElem::Slice { + start: slice.start, + end: slice.end, + step: slice.step, + } + } + } else { + // Dimension not specified in slices - use full range + SliceInfoElem::Slice { + start: 0, + end: None, + step: 1, + } + } + } + + slices + } + + pub fn swap_dims(mut tensor: SharedArray, dim1: usize, dim2: usize) -> SharedArray { + tensor.swap_axes(dim1, dim2); + + tensor + } + + pub fn permute(tensor: SharedArray, axes: &[usize]) -> SharedArray { + tensor.permuted_axes(axes.into_dimension()) + } + + /// Broadcasts the tensor to the given shape + pub(crate) fn expand(tensor: SharedArray, shape: Shape) -> SharedArray { + tensor + .broadcast(shape.into_dimension()) + .expect("The shapes should be broadcastable") + // need to convert view to owned array because NdArrayTensor expects owned array + // and try_into_owned_nocopy() panics for broadcasted arrays (zero strides) + .into_owned() + .into_shared() + } + + pub fn flip(tensor: SharedArray, axes: &[usize]) -> SharedArray { + let slice_items: Vec<_> = (0..tensor.shape().num_dims()) + .map(|i| { + if axes.contains(&i) { + SliceInfoElem::Slice { + start: 0, + end: None, + step: -1, + } + } else { + SliceInfoElem::Slice { + start: 0, + end: None, + step: 1, + } + } + }) + .collect(); + let slice_info = + SliceInfo::, IxDyn, IxDyn>::try_from(slice_items).unwrap(); + tensor.slice(slice_info).into_owned().into_shared() + } + + /// Unfold windows along a dimension. 
+ /// + /// # Warning + /// + /// This is a copy impl; `ndarray` doesn't expose the layout machinery + /// necessary to build the stride view. + /// + /// Returns a copy of the tensor with all complete windows of size `size` in dimension `dim`; + /// where windows are advanced by `step` at each index. + /// + /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`. + /// + /// # Arguments + /// + /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]`` + /// * `dim` - the dimension to unfold. + /// * `size` - the size of each unfolded window. + /// * `step` - the step between each window. + /// + /// # Returns + /// + /// A tensor view with shape ``[pre=..., windows, post=..., size]``. + #[allow(unused)] + pub(crate) fn unfold( + tensor: SharedArray, + dim: usize, + size: usize, + step: usize, + ) -> SharedArray { + let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step); + let windows = result_shape[dim]; + + let mut slices = vec![Slice::new(0, None, 1); tensor.shape().len()]; + let new_axis = slices.len(); + + let mut stack = Vec::with_capacity(windows); + for widx in 0..windows { + let start = widx * step; + let end = start + size; + slices[dim] = Slice::new(start as isize, Some(end as isize), 1); + + let mut window_slice = + tensor.slice(Self::to_slice_args_with_steps(&slices, slices.len()).as_slice()); + window_slice.insert_axis_inplace(Axis(new_axis)); + window_slice.swap_axes(dim, new_axis); + + stack.push(window_slice); + } + Self::concatenate(&stack, dim) + } +} + +#[cfg(feature = "simd")] +macro_rules! dispatch_binary_simd { + (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ + paste! { + let simd = match $elem::dtype() { + $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* + _ => Err(($lhs, $rhs)), + }; + match simd { + Ok(out) => return out, + Err(args) => args, + } + } + }}; + ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ + paste! { + let simd = match $elem::dtype() { + $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* + DType::QFloat(strategy) => match strategy.value { + QuantValue::Q8F | QuantValue::Q8S => try_binary_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs), + _ => Err(($lhs, $rhs)), + }, + _ => Err(($lhs, $rhs)), + }; + match simd { + Ok(out) => return out, + Err(args) => args, + } + } + }}; +} + +#[cfg(not(feature = "simd"))] +macro_rules! dispatch_binary_simd { + (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }}; + ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }}; +} + +#[cfg(feature = "simd")] +macro_rules! dispatch_binary_scalar_simd { + (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ + paste! { + let simd = match $elem::dtype() { + $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* + _ => Err($lhs), + }; + match simd { + Ok(out) => return out, + Err(args) => args, + } + } + }}; + ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ + paste! 
{ + let simd = match $elem::dtype() { + $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* + DType::QFloat(strategy) => match strategy.value { + QuantValue::Q8F | QuantValue::Q8S => try_binary_scalar_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs), + QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs) + }, + _ => Err($lhs), + }; + match simd { + Ok(out) => return out, + Err(args) => args, + } + } + }}; +} + +#[cfg(not(feature = "simd"))] +macro_rules! dispatch_binary_scalar_simd { + (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }}; + ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }}; +} + +#[cfg(feature = "simd")] +macro_rules! dispatch_cmp_simd { + ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ + paste! { + let simd = match $elem::dtype() { + $(DType::[<$ty:upper>] => try_cmp_simd::<$elem, $ty, $op>($lhs, $rhs),)* + DType::QFloat(strategy) => match strategy.value { + QuantValue::Q8F | QuantValue::Q8S => try_cmp_simd::<$elem, i8, $op>($lhs, $rhs), + QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err(($lhs, $rhs)) + }, + _ => Err(($lhs, $rhs)), + }; + match simd { + Ok(out) => return out, + Err(args) => args, + } + } + }}; +} + +#[cfg(not(feature = "simd"))] +macro_rules! dispatch_cmp_simd { + ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }}; +} + +#[cfg(feature = "simd")] +macro_rules! dispatch_cmp_scalar_simd { + ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ + paste! { + let simd = match $elem::dtype() { + $(DType::[<$ty:upper>] => try_cmp_scalar_simd::<$elem, $ty, $op>($lhs, $rhs),)* + DType::QFloat(strategy) => match strategy.value { + QuantValue::Q8F | QuantValue::Q8S => try_cmp_scalar_simd::<$elem, i8, $op>($lhs, $rhs), + QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs) + }, + _ => Err($lhs), + }; + match simd { + Ok(out) => return out, + Err(args) => args, + } + } + }}; +} + +#[cfg(not(feature = "simd"))] +macro_rules! dispatch_cmp_scalar_simd { + ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }}; +} + +#[cfg(feature = "simd")] +macro_rules! dispatch_unary_simd { + ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ + paste! { + let simd = match $elem::dtype() { + $(DType::[<$ty:upper>] => try_unary_simd::<$elem, $elem, $ty, $ty, $op>($lhs),)* + _ => Err($lhs), + }; + match simd { + Ok(out) => return out, + Err(args) => args, + } + } + }}; +} + +#[cfg(not(feature = "simd"))] +macro_rules! 
dispatch_unary_simd { + ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ $lhs }}; +} + +// Helper function to broadcast two tensors to a common shape for comparison operations +// Returns broadcasted views that can be safely zipped +fn broadcast_for_comparison<'a, E: Copy, S1, S2>( + lhs: &'a ndarray::ArrayBase, + rhs: &'a ndarray::ArrayBase, +) -> ( + ndarray::ArrayView<'a, E, ndarray::IxDyn>, + ndarray::ArrayView<'a, E, ndarray::IxDyn>, +) +where + S1: ndarray::Data, + S2: ndarray::Data, +{ + // Get shapes + let lhs_shape = lhs.shape(); + let rhs_shape = rhs.shape(); + + // Compute broadcast shape using ndarray's broadcast compatibility rules + let ndims = lhs_shape.len().max(rhs_shape.len()); + let mut broadcast_shape = vec![1; ndims]; + + for i in 0..ndims { + let lhs_dim = if i < lhs_shape.len() { + lhs_shape[lhs_shape.len() - 1 - i] + } else { + 1 + }; + let rhs_dim = if i < rhs_shape.len() { + rhs_shape[rhs_shape.len() - 1 - i] + } else { + 1 + }; + + if lhs_dim == rhs_dim { + broadcast_shape[ndims - 1 - i] = lhs_dim; + } else if lhs_dim == 1 { + broadcast_shape[ndims - 1 - i] = rhs_dim; + } else if rhs_dim == 1 { + broadcast_shape[ndims - 1 - i] = lhs_dim; + } else { + panic!( + "Incompatible shapes for broadcasting: {:?} and {:?}", + lhs_shape, rhs_shape + ); + } + } + + // Create IxDyn from broadcast shape + let broadcast_dim = ndarray::IxDyn(&broadcast_shape); + + // Broadcast both arrays + let lhs_broadcast = lhs + .broadcast(broadcast_dim.clone()) + .expect("Failed to broadcast lhs"); + let rhs_broadcast = rhs + .broadcast(broadcast_dim) + .expect("Failed to broadcast rhs"); + + (lhs_broadcast, rhs_broadcast) +} + +impl NdArrayMathOps +where + E: Copy + NdArrayElement, +{ + pub fn add(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = dispatch_binary_simd!( + E, VecAdd, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 + ); + + let array = &lhs + &rhs; + array.into_shared() + } + + pub fn add_scalar(lhs: SharedArray, rhs: E) -> SharedArray { + let lhs = dispatch_binary_scalar_simd!( + E, + VecAdd, + lhs, + rhs.elem(), + u8, + i8, + u16, + i16, + u32, + i32, + f32, + u64, + i64, + f64 + ); + + let array = lhs + rhs; + array.into_shared() + } + + pub fn sub(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = dispatch_binary_simd!( + E, VecSub, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64 + ); + + let array = lhs - rhs; + array.into_shared() + } + + pub fn sub_scalar(lhs: SharedArray, rhs: E) -> SharedArray { + let lhs = dispatch_binary_scalar_simd!( + E, + VecSub, + lhs, + rhs.elem(), + u8, + i8, + u16, + i16, + u32, + i32, + f32, + u64, + i64, + f64 + ); + + let array = lhs - rhs; + array.into_shared() + } + + pub fn mul(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = + dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, u32, i32, f32, f64); + + let array = lhs * rhs; + array.into_shared() + } + + pub fn mul_scalar(lhs: SharedArray, rhs: E) -> SharedArray { + let lhs = dispatch_binary_scalar_simd!( + noq, + E, + VecMul, + lhs, + rhs.elem(), + u16, + i16, + u32, + i32, + f32, + f64 + ); + + let array = lhs * rhs; + array.into_shared() + } + + pub fn div(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f32, f64); + + let array = lhs / rhs; + array.into_shared() + } + + pub fn div_scalar(lhs: SharedArray, rhs: E) -> SharedArray { + let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f32, f64); + + 
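+        // Reviewer note: the dispatch macros above follow a try-then-fallback
+        // pattern. Each `try_*_simd` call returns `Ok(result)` when the element
+        // type has a SIMD kernel (early return), or `Err(inputs)` handing
+        // ownership back so the scalar ndarray path below runs instead. With the
+        // `simd` feature disabled, the macros expand to a no-op passthrough.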
let array = lhs / rhs; + array.into_shared() + } + + pub fn remainder(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique + let mut out = lhs.into_owned(); + Zip::from(&mut out).and(&rhs).for_each(|out_elem, &b| { + // out_elem holds lhs value; read it before overwriting with remainder + let a_f = (*out_elem).to_f64(); + let b_f = b.to_f64(); + let r = a_f - b_f * (a_f / b_f).floor(); + *out_elem = r.elem(); + }); + out.into_shared() + } + + pub fn remainder_scalar(lhs: SharedArray, rhs: E) -> SharedArray + where + E: core::ops::Rem, + { + let array = lhs.mapv(|x| ((x % rhs) + rhs) % rhs); + array.into_shared() + } + + pub fn recip(tensor: SharedArray) -> SharedArray { + let tensor = dispatch_unary_simd!(E, RecipVec, tensor, f32); + + let array = tensor.map(|x| 1.elem::() / *x); + array.into_shared() + } + + /// Sum all elements - zero-copy for borrowed storage. + pub fn sum_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { + let sum = view.sum(); + ArrayD::from_elem(IxDyn(&[1]), sum).into_shared() + } + + /// Mean of all elements - zero-copy for borrowed storage. + pub fn mean_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { + let mean = view.mean().unwrap(); + ArrayD::from_elem(IxDyn(&[1]), mean).into_shared() + } + + /// Product of all elements - zero-copy for borrowed storage. + pub fn prod_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { + let prod = view.iter().fold(E::one(), |acc, &x| acc * x); + ArrayD::from_elem(IxDyn(&[1]), prod).into_shared() + } + + pub fn mean_dim(tensor: SharedArray, dim: usize) -> SharedArray { + let ndims = tensor.shape().num_dims(); + match ndims { + d if (1..=6).contains(&d) => keepdim!(dim, tensor, mean), + _ => panic!("Dim not supported {ndims}"), + } + } + + pub fn sum_dim(tensor: SharedArray, dim: usize) -> SharedArray { + let ndims = tensor.shape().num_dims(); + match ndims { + d if (1..=6).contains(&d) => keepdim!(dim, tensor, sum), + _ => panic!("Dim not supported {ndims}"), + } + } + + pub fn prod_dim(tensor: SharedArray, dim: usize) -> SharedArray { + let ndims = tensor.shape().num_dims(); + match ndims { + d if (1..=6).contains(&d) => keepdim!(dim, tensor, prod), + _ => panic!("Dim not supported {ndims}"), + } + } + + pub fn cumsum(tensor: SharedArray, dim: usize) -> SharedArray { + cumsum_dim(tensor, dim) + } + + pub fn cumprod(tensor: SharedArray, dim: usize) -> SharedArray { + cumprod_dim(tensor, dim) + } + + pub fn select( + tensor: SharedArray, + dim: usize, + indices: SharedArray, + ) -> SharedArray { + let array = tensor.select( + Axis(dim), + &indices + .into_iter() + .map(|i| i.elem::() as usize) + .collect::>(), + ); + + array.into_shared() + } + + pub fn select_assign( + tensor: SharedArray, + dim: usize, + indices: SharedArray, + value: SharedArray, + ) -> SharedArray { + let mut output_array = tensor.into_owned(); + + for (index_value, index) in indices.into_iter().enumerate() { + let mut view = output_array.index_axis_mut(Axis(dim), index.elem::() as usize); + let value = value.index_axis(Axis(dim), index_value); + + view.zip_mut_with(&value, |a, b| *a += *b); + } + + output_array.into_shared() + } + + pub(crate) fn elementwise_op( + lhs: SharedArray, + rhs: SharedArray, + var_name: impl FnMut(&E, &OtherE) -> E, + ) -> SharedArray { + let lhs = lhs.broadcast(rhs.dim()).unwrap_or(lhs.view()); + let rhs = rhs.broadcast(lhs.dim()).unwrap_or(rhs.view()); + + Zip::from(lhs).and(rhs).map_collect(var_name).into_shared() + } + + pub(crate) 
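+    // Worked example (reviewer note) for `remainder`/`remainder_scalar` defined
+    // earlier in this impl: r = a - b * floor(a / b) is a floored (Python-style)
+    // remainder whose sign follows the divisor, unlike Rust's truncating `%`:
+    //
+    //     a = -7, b = 3  =>  r = -7 - 3 * floor(-7 / 3) = -7 - 3 * (-3) = 2
+    //     (whereas -7 % 3 == -1 in plain Rust)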
fn elementwise_op_scalar( + lhs: SharedArray, + var_name: impl FnMut(E) -> E, + ) -> SharedArray { + lhs.mapv(var_name).into_shared() + } + + pub(crate) fn abs(tensor: SharedArray) -> SharedArray { + let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, i32, f32, f64); + + tensor.mapv_into(|a| a.abs_elem()).into_shared() + } + + pub(crate) fn equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = dispatch_cmp_simd!( + E, VecEquals, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 + ); + + // Use the helper to broadcast both arrays to a common shape + let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); + // Now we can safely zip and compare + Zip::from(&lhs_broadcast) + .and(&rhs_broadcast) + .map_collect(|&lhs, &rhs| lhs == rhs) + .into_shared() + } + + pub(crate) fn equal_elem(lhs: SharedArray, rhs: E) -> SharedArray { + let lhs = dispatch_cmp_scalar_simd!( + E, + VecEquals, + lhs, + rhs.elem(), + u8, + i8, + u16, + i16, + u32, + f32, + i32, + u64, + i64, + f64 + ); + + lhs.mapv(|a| a == rhs).into_shared() + } + + pub(crate) fn sign_op(tensor: SharedArray) -> SharedArray + where + E: Signed, + { + let zero = 0.elem(); + let one = 1.elem::(); + + tensor + .mapv(|x| { + if x == zero { + zero + } else { + match x.is_positive() { + true => one, + false => -one, + } + } + }) + .into_shared() + } +} + +impl NdArrayMathOps +where + E: Copy + NdArrayElement + PartialOrd, +{ + /// Max of all elements - zero-copy for borrowed storage. + pub fn max_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { + let max = view + .iter() + .copied() + .reduce(|a, b| if a > b { a } else { b }) + .expect("Cannot compute max of empty tensor"); + ArrayD::from_elem(IxDyn(&[1]), max).into_shared() + } + + /// Min of all elements - zero-copy for borrowed storage. + pub fn min_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray { + let min = view + .iter() + .copied() + .reduce(|a, b| if a < b { a } else { b }) + .expect("Cannot compute min of empty tensor"); + ArrayD::from_elem(IxDyn(&[1]), min).into_shared() + } + + /// Argmax along dimension - zero-copy for borrowed storage. + pub fn argmax_view( + view: ArrayView<'_, E, IxDyn>, + dim: usize, + ) -> SharedArray { + arg_view(view, dim, CmpType::Max) + } + + /// Argmin along dimension - zero-copy for borrowed storage. 
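+    // Worked example (reviewer note): argmax/argmin share `arg_view` below. For a
+    // [2, 3] input [[1, 5, 3], [7, 2, 4]], argmax along dim 1 compares within each
+    // row and keeps the reduced axis at size 1, returning indices [[1], [0]].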
+ pub fn argmin_view( + view: ArrayView<'_, E, IxDyn>, + dim: usize, + ) -> SharedArray { + arg_view(view, dim, CmpType::Min) + } + + pub fn cummin(tensor: SharedArray, dim: usize) -> SharedArray { + cummin_dim(tensor, dim) + } + + pub fn cummax(tensor: SharedArray, dim: usize) -> SharedArray { + cummax_dim(tensor, dim) + } + + pub fn argmax( + tensor: SharedArray, + dim: usize, + ) -> SharedArray { + arg(tensor, dim, CmpType::Max) + } + + pub fn argmin( + tensor: SharedArray, + dim: usize, + ) -> SharedArray { + arg(tensor, dim, CmpType::Min) + } + + pub fn clamp_min(tensor: SharedArray, min: E) -> SharedArray { + let mut tensor = dispatch_binary_scalar_simd!( + E, + VecMax, + tensor, + min.elem(), + u8, + i8, + u16, + i16, + u32, + i32, + f32, + u64, + i64, + f64 + ); + + tensor.mapv_inplace(|x| match x < min { + true => min, + false => x, + }); + + tensor + } + + pub fn clamp_max(tensor: SharedArray, max: E) -> SharedArray { + let mut tensor = dispatch_binary_scalar_simd!( + E, + VecMin, + tensor, + max.elem(), + u8, + i8, + u16, + i16, + u32, + i32, + f32, + u64, + i64, + f64 + ); + + tensor.mapv_inplace(|x| match x > max { + true => max, + false => x, + }); + + tensor + } + + pub fn clamp(tensor: SharedArray, min: E, max: E) -> SharedArray { + let mut tensor = dispatch_binary_scalar_simd!( + E, + VecClamp, + tensor, + (min.elem(), max.elem()), + u8, + i8, + u16, + i16, + u32, + i32, + f32, + u64, + i64, + f64 + ); + + tensor.mapv_inplace(|x| match x < min { + true => min, + false => match x > max { + true => max, + false => x, + }, + }); + + tensor + } + + pub(crate) fn greater(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = dispatch_cmp_simd!( + E, VecGreater, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 + ); + + // Use the helper to broadcast both arrays to a common shape + let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); + // Now we can safely zip and compare + Zip::from(&lhs_broadcast) + .and(&rhs_broadcast) + .map_collect(|&lhs, &rhs| lhs > rhs) + .into_shared() + } + + pub(crate) fn greater_elem(lhs: SharedArray, rhs: E) -> SharedArray { + let lhs = dispatch_cmp_scalar_simd!( + E, + VecGreater, + lhs, + rhs.elem(), + u8, + i8, + u16, + i16, + u32, + f32, + i32, + u64, + i64, + f64 + ); + + lhs.mapv(|a| a > rhs).into_shared() + } + + pub(crate) fn greater_equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = dispatch_cmp_simd!( + E, + VecGreaterEq, + lhs, + rhs, + u8, + i8, + u16, + i16, + u32, + f32, + i32, + u64, + i64, + f64 + ); + + // Use the helper to broadcast both arrays to a common shape + let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); + // Now we can safely zip and compare + Zip::from(&lhs_broadcast) + .and(&rhs_broadcast) + .map_collect(|&lhs, &rhs| lhs >= rhs) + .into_shared() + } + + pub(crate) fn greater_equal_elem(lhs: SharedArray, rhs: E) -> SharedArray { + let lhs = dispatch_cmp_scalar_simd!( + E, + VecGreaterEq, + lhs, + rhs.elem(), + u8, + i8, + u16, + i16, + u32, + f32, + i32, + u64, + i64, + f64 + ); + + lhs.mapv(|a| a >= rhs).into_shared() + } + + pub(crate) fn lower_equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = dispatch_cmp_simd!( + E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 + ); + + // Use the helper to broadcast both arrays to a common shape + let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); + // Now we can safely zip and compare + Zip::from(&lhs_broadcast) 
+ .and(&rhs_broadcast) + .map_collect(|&lhs, &rhs| lhs <= rhs) + .into_shared() + } + + pub(crate) fn lower_equal_elem(lhs: SharedArray, rhs: E) -> SharedArray { + let lhs = dispatch_cmp_scalar_simd!( + E, + VecLowerEq, + lhs, + rhs.elem(), + u8, + i8, + u16, + i16, + u32, + f32, + i32, + u64, + i64, + f64 + ); + + lhs.mapv(|a| a <= rhs).into_shared() + } + + pub(crate) fn lower(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = dispatch_cmp_simd!( + E, VecLower, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64 + ); + + // Use the helper to broadcast both arrays to a common shape + let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); + + // Now we can safely zip and compare + Zip::from(&lhs_broadcast) + .and(&rhs_broadcast) + .map_collect(|&lhs, &rhs| lhs < rhs) + .into_shared() + } + + pub(crate) fn lower_elem(lhs: SharedArray, rhs: E) -> SharedArray { + let lhs = dispatch_cmp_scalar_simd!( + E, + VecLower, + lhs, + rhs.elem(), + u8, + i8, + u16, + i16, + u32, + f32, + i32, + u64, + i64, + f64 + ); + + lhs.mapv(|a| a < rhs).into_shared() + } +} + +pub struct NdArrayBitOps(PhantomData); + +impl NdArrayBitOps { + pub(crate) fn bitand(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = + dispatch_binary_simd!(I, VecBitAnd, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64); + + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() & (b.elem::())).elem() + }) + } + + pub(crate) fn bitand_scalar(lhs: SharedArray, rhs: I) -> SharedArray { + let lhs = dispatch_binary_scalar_simd!( + I, + VecBitAnd, + lhs, + rhs.elem(), + i8, + u8, + i16, + u16, + i32, + u32, + i64, + u64 + ); + + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() & rhs.elem::()).elem() + }) + } + + pub(crate) fn bitor(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = + dispatch_binary_simd!(I, VecBitOr, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64); + + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() | (b.elem::())).elem() + }) + } + + pub(crate) fn bitor_scalar(lhs: SharedArray, rhs: I) -> SharedArray { + let lhs = dispatch_binary_scalar_simd!( + I, + VecBitOr, + lhs, + rhs.elem(), + i8, + u8, + i16, + u16, + i32, + u32, + i64, + u64 + ); + + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() | rhs.elem::()).elem() + }) + } + + pub(crate) fn bitxor(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + let (lhs, rhs) = + dispatch_binary_simd!(I, VecBitXor, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64); + + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() ^ (b.elem::())).elem() + }) + } + + pub(crate) fn bitxor_scalar(lhs: SharedArray, rhs: I) -> SharedArray { + let lhs = dispatch_binary_scalar_simd!( + I, + VecBitXor, + lhs, + rhs.elem(), + i8, + u8, + i16, + u16, + i32, + u32, + i64, + u64 + ); + + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() ^ rhs.elem::()).elem() + }) + } + + pub(crate) fn bitnot(tensor: SharedArray) -> SharedArray { + let tensor = + dispatch_unary_simd!(I, VecBitNot, tensor, i8, u8, i16, u16, i32, u32, i64, u64); + + NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::()).elem()) + } +} + +pub struct NdArrayBoolOps; + +// Rust booleans are either `00000000` or `00000001`, so bitwise and/or is fine, but bitwise not would +// produce invalid values. 
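+// Illustrative check (added for review, not part of the backend API): shows why
+// bitwise NOT on the underlying byte cannot stand in for boolean negation.
+#[cfg(test)]
+mod bool_repr_check {
+    #[test]
+    fn bitwise_not_escapes_bool_bit_patterns() {
+        let t = true as u8; // 0b0000_0001
+        let f = false as u8; // 0b0000_0000
+        // AND/OR stay within the valid {0, 1} patterns...
+        assert_eq!(t & f, 0);
+        assert_eq!(t | f, 1);
+        // ...but NOT does not: neither result is a valid `bool` representation.
+        assert_eq!(!t, 0b1111_1110);
+        assert_eq!(!f, 0b1111_1111);
+    }
+}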
+impl NdArrayBoolOps { + pub(crate) fn equal(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + #[cfg(feature = "simd")] + let (lhs, rhs) = match try_cmp_simd::(lhs, rhs) { + Ok(out) => return out, + Err(args) => args, + }; + + // Use the helper to broadcast both arrays to a common shape + let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); + // Now we can safely zip and compare + Zip::from(&lhs_broadcast) + .and(&rhs_broadcast) + .map_collect(|&lhs, &rhs| lhs == rhs) + .into_shared() + } + + pub(crate) fn equal_elem(lhs: SharedArray, rhs: bool) -> SharedArray { + #[cfg(feature = "simd")] + let lhs = match try_cmp_scalar_simd::(lhs, rhs.elem()) { + Ok(out) => return out, + Err(args) => args, + }; + + lhs.mapv(|a| a == rhs).into_shared() + } + + pub(crate) fn and(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + #[cfg(feature = "simd")] + let (lhs, rhs) = match try_binary_simd::(lhs, rhs) { + Ok(out) => return out, + Err(args) => args, + }; + + // Use the helper to broadcast both arrays to a common shape + let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); + // Now we can safely zip and compare + Zip::from(&lhs_broadcast) + .and(&rhs_broadcast) + .map_collect(|&lhs, &rhs| lhs && rhs) + .into_shared() + } + + pub(crate) fn or(lhs: SharedArray, rhs: SharedArray) -> SharedArray { + #[cfg(feature = "simd")] + let (lhs, rhs) = match try_binary_simd::(lhs, rhs) { + Ok(out) => return out, + Err(args) => args, + }; + + // Use the helper to broadcast both arrays to a common shape + let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs); + // Now we can safely zip and compare + Zip::from(&lhs_broadcast) + .and(&rhs_broadcast) + .map_collect(|&lhs, &rhs| lhs || rhs) + .into_shared() + } + + /// Any element is true - zero-copy for borrowed storage. + pub fn any_view(view: ArrayView<'_, bool, IxDyn>) -> bool { + view.iter().any(|&x| x) + } + + /// All elements are true - zero-copy for borrowed storage. + pub fn all_view(view: ArrayView<'_, bool, IxDyn>) -> bool { + view.iter().all(|&x| x) + } +} + +enum CmpType { + Min, + Max, +} + +fn arg( + tensor: SharedArray, + dim: usize, + cmp: CmpType, +) -> SharedArray { + arg_view(tensor.view(), dim, cmp) +} + +/// View-based argmax/argmin - zero-copy for borrowed storage. +fn arg_view( + view: ArrayView<'_, E, IxDyn>, + dim: usize, + cmp: CmpType, +) -> SharedArray { + let mut reshape = view.shape().to_vec(); + reshape[dim] = 1; + + let output = view.map_axis(Axis(dim), |arr| { + // Find the min/max value in the array, and return its index. 
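+        // Note (reviewer): the fold below uses strict comparisons, so on ties the
+        // earliest index wins; e.g. argmax of [3, 7, 7] returns 1, not 2.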
+ let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| { + let cmp = match cmp { + CmpType::Min => e < &acc.0, + CmpType::Max => e > &acc.0, + }; + + if cmp { (*e, idx) } else { acc } + }); + + (idx as i64).elem() + }); + + let output = output.to_shape(Dim(reshape.as_slice())).unwrap(); + + output.into_shared() +} + +#[cfg(test)] +mod tests { + use burn_backend::TensorData; + + use crate::NdArrayTensor; + + use super::*; + + #[test] + fn should_generate_row_major_layout_for_cat() { + let expected_shape: &[usize] = &[4, 6, 2]; + let expected_strides: &[isize] = &[12, 2, 1]; + let NdArrayTensor::I32(expected_storage) = NdArrayTensor::from_data(TensorData::from([ + [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]], + [[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]], + [[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]], + [[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]], + ])) else { + panic!() + }; + let expected_array = expected_storage.into_shared(); + + let NdArrayTensor::I32(tensor_storage) = NdArrayTensor::from_data(TensorData::from([ + [1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + [19, 20, 21, 22, 23, 24], + ])) else { + panic!() + }; + let tensor = tensor_storage.into_shared(); + + // unsqueeze dim on the outermost axis + let array = NdArrayOps::reshape(tensor, Shape::from([4, 6, 1])); + let NdArrayTensor::I32(zeros_storage) = + NdArrayTensor::from_data(TensorData::zeros::([4, 6, 1])) + else { + panic!() + }; + let zeros = zeros_storage.into_shared(); + // make `ndarray` concatenates array on the outermost axis + let array = NdArrayOps::cat([array, zeros].to_vec(), 2); + + assert!(array.is_standard_layout()); + assert_eq!(array.shape(), expected_shape); + assert_eq!(array.strides(), expected_strides); + assert_eq!( + array.into_iter().collect::>(), + expected_array.into_iter().collect::>(), + ); + } +} diff --git a/crates/burn/src/ops/bool_tensor.rs b/crates/burn/src/ops/bool_tensor.rs new file mode 100644 index 00000000..1d1f26d3 --- /dev/null +++ b/crates/burn/src/ops/bool_tensor.rs @@ -0,0 +1,241 @@ +// Language +use alloc::vec; +use alloc::vec::Vec; +use burn_backend::Scalar; +use burn_backend::{ElementConversion, TensorMetadata, tensor::FloatTensor}; +use burn_backend::{ + backend::ExecutionError, + ops::BoolTensorOps, + tensor::{BoolTensor, IntTensor}, +}; +use burn_std::{BoolDType, FloatDType, IntDType}; +use ndarray::IntoDimension; + +// Current crate +use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement}; +use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor}; +use crate::{ + NdArrayDevice, SharedArray, execute_with_float_out_dtype, execute_with_int_out_dtype, slice, +}; + +// Workspace crates +use burn_backend::{Shape, TensorData, backend::Backend}; + +use super::{NdArrayBoolOps, NdArrayOps}; + +impl BoolTensorOps + for NdArray +where + NdArrayTensor: From>, + NdArrayTensor: From>, +{ + fn bool_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor { + if !data.dtype.is_bool() { + unimplemented!("Unsupported dtype for `bool_from_data`") + } + NdArrayTensor::from_data(data) + } + + async fn bool_into_data(tensor: NdArrayTensor) -> Result { + Ok(tensor.into_data()) + } + + fn bool_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor { + tensor + } + + fn bool_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { + NdArrayOps::reshape(tensor.bool(), shape).into() + } + + fn bool_slice(tensor: NdArrayTensor, slices: 
&[burn_backend::Slice]) -> NdArrayTensor { + slice!(tensor, slices) + } + + fn bool_into_int(tensor: NdArrayTensor, out_dtype: IntDType) -> NdArrayTensor { + // Use mapv directly instead of collecting to Vec and going through TensorData + execute_with_int_out_dtype!( + out_dtype, + I, + tensor.bool().mapv(|b| b.elem::()).into_shared().into() + ) + } + + fn bool_device(_tensor: &NdArrayTensor) -> as Backend>::Device { + NdArrayDevice::Cpu + } + + fn bool_empty( + shape: Shape, + _device: & as Backend>::Device, + dtype: BoolDType, + ) -> NdArrayTensor { + Self::bool_zeros(shape, _device, dtype) + } + + fn bool_zeros( + shape: Shape, + _device: & as Backend>::Device, + _dtype: BoolDType, + ) -> NdArrayTensor { + let values = vec![false; shape.num_elements()]; + NdArrayTensor::from_data(TensorData::new(values, shape)) + } + + fn bool_ones( + shape: Shape, + _device: & as Backend>::Device, + _dtype: BoolDType, + ) -> NdArrayTensor { + let values = vec![true; shape.num_elements()]; + NdArrayTensor::from_data(TensorData::new(values, shape)) + } + + fn bool_slice_assign( + tensor: NdArrayTensor, + slices: &[burn_backend::Slice], + value: NdArrayTensor, + ) -> NdArrayTensor { + NdArrayOps::slice_assign(tensor.bool(), slices, value.bool()).into() + } + + fn bool_cat(tensors: Vec, dim: usize) -> NdArrayTensor { + NdArrayOps::cat(tensors.into_iter().map(|it| it.bool()).collect(), dim).into() + } + + fn bool_equal(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayBoolOps::equal(lhs.bool(), rhs.bool()).into() + } + + fn bool_not(tensor: NdArrayTensor) -> NdArrayTensor { + tensor.bool().mapv(|a| !a).into_shared().into() + } + + fn bool_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayBoolOps::and(lhs.bool(), rhs.bool()).into() + } + + fn bool_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayBoolOps::or(lhs.bool(), rhs.bool()).into() + } + + fn bool_into_float(tensor: NdArrayTensor, out_dtype: FloatDType) -> FloatTensor { + execute_with_float_out_dtype!( + out_dtype, + E, + tensor.bool().mapv(|b| b.elem::()).into_shared().into() + ) + } + + fn bool_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor { + NdArrayOps::swap_dims(tensor.bool(), dim1, dim2).into() + } + + fn bool_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { + tensor.bool().permuted_axes(axes.into_dimension()).into() + } + + fn bool_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { + NdArrayOps::expand(tensor.bool(), shape).into() + } + + fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!(indices, I, |indices: SharedArray| -> NdArrayTensor { + let tensor_bool = tensor.bool(); + let indices_vec: Vec = indices + .into_iter() + .map(|i| i.elem::() as usize) + .collect(); + + let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec); + selected.into_shared().into() + }) + } + + fn bool_select_or( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + execute_with_int_dtype!(indices, I, |indices: SharedArray| -> NdArrayTensor { + let mut output_array = tensor.bool().into_owned(); + let value_bool = value.bool(); + + for (index_value, index) in indices.into_iter().enumerate() { + let index_usize = index.elem::() as usize; + let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize); + let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value); + // For boolean 
tensors, select_assign should use logical OR operation + view.zip_mut_with(&value_slice, |a, b| *a = *a || *b); + } + output_array.into_shared().into() + }) + } + + fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { + NdArrayOps::flip(tensor.bool(), axes).into() + } + + fn bool_unfold(tensor: NdArrayTensor, dim: usize, size: usize, step: usize) -> NdArrayTensor { + NdArrayOps::unfold(tensor.bool(), dim, size, step).into() + } + + fn bool_mask_where( + tensor: BoolTensor, + mask: BoolTensor, + value: BoolTensor, + ) -> BoolTensor { + NdArrayOps::mask_where(tensor.bool(), mask.bool(), value.bool()).into() + } + + fn bool_mask_fill( + tensor: BoolTensor, + mask: BoolTensor, + value: Scalar, + ) -> BoolTensor { + NdArrayOps::mask_fill(tensor.bool(), mask.bool(), value.elem()).into() + } + + fn bool_gather( + dim: usize, + tensor: BoolTensor, + indices: IntTensor, + ) -> BoolTensor { + execute_with_int_dtype!(indices, |indices| NdArrayOps::gather( + dim, + tensor.bool(), + indices + )) + } + + fn bool_scatter_or( + dim: usize, + tensor: BoolTensor, + indices: IntTensor, + value: BoolTensor, + ) -> BoolTensor { + execute_with_int_dtype!(indices, |indices| NdArrayOps::scatter( + dim, + tensor.bool(), + indices, + value.bool() + )) + } + + fn bool_equal_elem(lhs: BoolTensor, rhs: Scalar) -> BoolTensor { + NdArrayBoolOps::equal_elem(lhs.bool(), rhs.elem()).into() + } + + fn bool_any(tensor: BoolTensor) -> BoolTensor { + // Use view() for zero-copy on borrowed storage with short-circuit evaluation + let result = NdArrayBoolOps::any_view(tensor.bool().view()); + NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1]))) + } + + fn bool_all(tensor: BoolTensor) -> BoolTensor { + // Use view() for zero-copy on borrowed storage with short-circuit evaluation + let result = NdArrayBoolOps::all_view(tensor.bool().view()); + NdArrayTensor::from_data(TensorData::new(vec![result], Shape::new([1]))) + } +} diff --git a/crates/burn/src/ops/conv.rs b/crates/burn/src/ops/conv.rs new file mode 100644 index 00000000..5fb2cad5 --- /dev/null +++ b/crates/burn/src/ops/conv.rs @@ -0,0 +1,574 @@ +use burn_backend::{ + ElementConversion, + ops::{ + ConvOptions, ConvTransposeOptions, + conv::{calculate_conv_output_size, calculate_conv_transpose_output_size}, + }, +}; +use ndarray::{ + Array3, Array4, Array5, ArrayView2, ArrayView3, ArrayViewMut2, ArrayViewMut3, Axis, Dim, s, +}; + +use crate::{ + NdArrayElement, SharedArray, iter_par, iter_range_par, + ops::padding::{apply_padding_4d, apply_padding_5d}, + run_par, + sharing::UnsafeSharedRef, + tensor::NdArrayTensor, +}; + +#[inline(always)] +fn conv2d_mad_inner( + mut output: ArrayViewMut2, + x: ArrayView2, + k: E, + k_xy: (usize, usize), + out_xy: (usize, usize), + stride: (usize, usize), + dilation: (usize, usize), +) { + let (kh, kw) = k_xy; + let (out_width, out_height) = out_xy; + let (stride_width, stride_height) = stride; + let (dilation_width, dilation_height) = dilation; + + for oh in 0..out_height { + // Construct a sub-slice view of the input row. + // This is done upfront so that rustc does not have to emit bounds checks + // in the hot loop below. + let ir = x + .row(oh * stride_height + kh * dilation_height) + .to_slice() + .unwrap(); + + // Ditto. Construct a sub-slice view of the output row, and explicitly specify + // the bounds upfront as 0..out_width so that rustc can make the assumption + // that all accesses are in-bounds in the below loop. 
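+        // Reviewer sketch of that pattern in isolation (assumed names):
+        //
+        //     let row: &mut [f32] = &mut buf[0..out_width]; // one bounds check here
+        //     for ow in 0..out_width {
+        //         row[ow] += 1.0; // provably in-bounds: no per-iteration check
+        //     }
+        //
+        // Re-slicing to exactly `0..out_width` lets rustc/LLVM hoist the bounds
+        // check out of the loop, which is what enables vectorization below.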
+ let mut or = output.row_mut(oh); + let or = &mut or.as_slice_mut().unwrap()[0..out_width]; + + #[allow(clippy::needless_range_loop)] + for ow in 0..out_width { + let iw = ow * stride_width + kw * dilation_width; + or[ow] += ir[iw] * k; + } + } +} + +#[inline(always)] +fn conv3d_mad_inner( + mut output: ArrayViewMut3, + x: ArrayView3, + k: E, + k_xyz: (usize, usize, usize), + out_xyz: (usize, usize, usize), + stride: (usize, usize, usize), + dilation: (usize, usize, usize), +) { + let (kd, kh, kw) = k_xyz; + let (out_width, out_height, out_depth) = out_xyz; + let (stride_width, stride_height, stride_depth) = stride; + let (dilation_width, dilation_height, dilation_depth) = dilation; + + for od in 0..out_depth { + let id = od * stride_depth + kd * dilation_depth; + + for oh in 0..out_height { + let ih = oh * stride_height + kh * dilation_height; + + // Construct a sub-slice view of the input row. + // This is done upfront so that rustc does not have to emit bounds checks + // in the hot loop below. + let ir = x.slice(s![id, ih, ..]).to_slice().unwrap(); + + // Ditto. Construct a sub-slice view of the output row, and explicitly specify + // the bounds upfront as 0..out_width so that rustc can make the assumption + // that all accesses are in-bounds in the below loop. + let or = &mut output + .slice_mut(s![od, oh, 0..out_width]) + .into_slice() + .unwrap()[0..out_width]; + + #[allow(clippy::needless_range_loop)] + for ow in 0..out_width { + let iw = ow * stride_width + kw * dilation_width; + or[ow] += ir[iw] * k; + } + } + } +} + +pub(crate) fn conv2d( + x: SharedArray, + weight: SharedArray, + bias: Option>, + options: ConvOptions<2>, +) -> SharedArray +where + NdArrayTensor: From>, +{ + let [dilation_height, dilation_width] = options.dilation; + let [padding_height, padding_width] = options.padding; + let [stride_height, stride_width] = options.stride; + let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap(); + let [out_channels, in_channels, kernel_height, kernel_width] = + weight.shape().try_into().unwrap(); + let channels_per_group = out_channels / options.groups; + + let out_height = calculate_conv_output_size( + kernel_height, + stride_height, + padding_height, + dilation_height, + in_height, + ); + let out_width = calculate_conv_output_size( + kernel_width, + stride_width, + padding_width, + dilation_width, + in_width, + ); + + let x = apply_padding_4d::(x, options.padding, 0i32.elem()); + + // Convert inputs from dynamic indexes to static to improve perf. + let x = x.into_dimensionality::().unwrap(); + let weights = weight.into_dimensionality::().unwrap(); + + let mut output = Array3::zeros(Dim([batch_size * out_channels, out_height, out_width])); + + run_par!(|| { + iter_par!(output.axis_iter_mut(Axis(0))) + .enumerate() + .for_each( + #[inline(never)] + |(k, mut output)| { + let b = k / out_channels; + let oc = k % out_channels; + let g = oc / channels_per_group; + + for ic in (in_channels * g)..(in_channels * (g + 1)) { + let weight_ic = ic - (g * in_channels); + + let x = x.slice(s![b, ic, .., ..]); + let k = weights.slice(s![oc, weight_ic, .., ..]); + + for kh in 0..kernel_height { + for kw in 0..kernel_width { + let k = k[[kh, kw]]; + + // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization + // in the case that the stride/dilation is 1. 
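+                            // Reviewer note: both branches below invoke the same
+                            // `conv2d_mad_inner` call. Splitting on the unit
+                            // stride/dilation tuple gives the optimizer a branch
+                            // in which those values are known constants (1), so
+                            // the inner loop can be auto-vectorized; the other
+                            // branch keeps the general strided fallback.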
+ #[allow(clippy::if_same_then_else)] + if (1, 1, 1, 1) + == ( + stride_width, + stride_height, + dilation_width, + dilation_height, + ) + { + conv2d_mad_inner( + output.view_mut(), + x.view(), + k, + (kh, kw), + (out_width, out_height), + (stride_width, stride_height), + (dilation_width, dilation_height), + ); + } else { + conv2d_mad_inner( + output.view_mut(), + x.view(), + k, + (kh, kw), + (out_width, out_height), + (stride_width, stride_height), + (dilation_width, dilation_height), + ); + } + } + } + } + + if let Some(bias) = &bias { + let bias = bias[oc]; + + for oh in 0..out_height { + // Get a mutable slice reference to the row we're looping over. + // We explicitly define the bounds to 0..out_width so that rustc can make + // the assumption that all accesses are in-bounds. + let mut or = output.row_mut(oh); + let or = &mut or.as_slice_mut().unwrap()[0..out_width]; + + #[allow(clippy::needless_range_loop)] + for ow in 0..out_width { + or[ow] += bias; + } + } + } + }, + ); + }); + + output + .to_shape([batch_size, out_channels, out_height, out_width]) + .unwrap() + .into_dyn() + .into_shared() +} + +pub(crate) fn conv_transpose2d( + x: SharedArray, + weight: SharedArray, + bias: Option>, + options: ConvTransposeOptions<2>, +) -> SharedArray { + let [dilation_height, dilation_width] = options.dilation; + let [padding_height, padding_width] = options.padding; + let [stride_height, stride_width] = options.stride; + let [out_padding_height, out_padding_width] = options.padding_out; + let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap(); + let [in_channels, out_channels, kernel_height, kernel_width] = + weight.shape().try_into().unwrap(); + + let out_height = calculate_conv_transpose_output_size( + kernel_height, + stride_height, + padding_height, + out_padding_height, + dilation_height, + in_height, + ); + let out_width = calculate_conv_transpose_output_size( + kernel_width, + stride_width, + padding_width, + out_padding_width, + dilation_width, + in_width, + ); + + let x = x; + let mut output = Array4::zeros(Dim([ + batch_size, + out_channels * options.groups, + out_height, + out_width, + ])); + + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe { + let b = k / (out_channels * options.groups); + let oc = k % out_channels; + let g = (k / out_channels) % options.groups; + + let output = unsafe_shared_out.get(); + + let oc_out = oc + (out_channels * g); + let ic_start = g * (in_channels / options.groups); + let ic_end = ic_start + in_channels / options.groups; + + for ic in ic_start..ic_end { + for ih in 0..in_height { + for iw in 0..in_width { + for kh in 0..kernel_height { + for kw in 0..kernel_width { + let oh = ih * stride_height + kh * dilation_height; + let ow = iw * stride_width + kw * dilation_width; + + if oh >= out_height + padding_height + || ow >= out_width + padding_width + || oh < padding_height + || ow < padding_width + { + continue; + } + + let oh = oh - padding_height; + let ow = ow - padding_width; + + output[[b, oc_out, oh, ow]] += + x[[b, ic, ih, iw]] * weight[[ic, oc, kh, kw]]; + } + } + } + } + } + + if let Some(bias) = &bias { + for oh in 0..out_height { + for ow in 0..out_width { + output[[b, oc_out, oh, ow]] += bias[oc_out]; + } + } + } + }); + }); + + output.into_dyn().into_shared() +} + +pub(crate) fn conv3d( + x: SharedArray, + weight: SharedArray, + bias: Option>, + options: ConvOptions<3>, +) -> SharedArray +where + 
NdArrayTensor: From>, +{ + let [dilation_depth, dilation_height, dilation_width] = options.dilation; + let [padding_depth, padding_height, padding_width] = options.padding; + let [stride_depth, stride_height, stride_width] = options.stride; + let [batch_size, _in_channels, in_depth, in_height, in_width] = x.shape().try_into().unwrap(); + let [ + out_channels, + in_channels, + kernel_depth, + kernel_height, + kernel_width, + ] = weight.shape().try_into().unwrap(); + let out_c_per_group = out_channels / options.groups; + + let out_depth = calculate_conv_output_size( + kernel_depth, + stride_depth, + padding_depth, + dilation_depth, + in_depth, + ); + let out_height = calculate_conv_output_size( + kernel_height, + stride_height, + padding_height, + dilation_height, + in_height, + ); + let out_width = calculate_conv_output_size( + kernel_width, + stride_width, + padding_width, + dilation_width, + in_width, + ); + + let x = apply_padding_5d::(x, options.padding, 0i32.elem()); + + // Convert inputs from dynamic indexes to static to improve perf. + let x = x.into_dimensionality::().unwrap(); + let weights = weight.into_dimensionality::().unwrap(); + + let mut output = Array4::zeros(Dim([ + batch_size * out_channels, + out_depth, + out_height, + out_width, + ])); + + run_par!(|| { + iter_par!(output.axis_iter_mut(Axis(0))) + .enumerate() + .for_each( + #[inline(never)] + |(k, mut output)| { + let b = k / out_channels; + let oc = k % out_channels; + let g = oc / out_c_per_group; + + for ic in (in_channels * g)..(in_channels * (g + 1)) { + let weight_ic = ic - (g * in_channels); + + let x = x.slice(s![b, ic, .., .., ..]); + let k = weights.slice(s![oc, weight_ic, .., .., ..]); + + for kd in 0..kernel_depth { + for kh in 0..kernel_height { + for kw in 0..kernel_width { + let k = k[[kd, kh, kw]]; + + // NOTE: This function call is duplicated twice so that the compiler can perform auto-vectorization + // in the case that the stride/dilation is 1. + #[allow(clippy::if_same_then_else)] + if (1, 1, 1, 1, 1, 1) + == ( + stride_width, + stride_height, + stride_depth, + dilation_width, + dilation_height, + dilation_depth, + ) + { + conv3d_mad_inner( + output.view_mut(), + x.view(), + k, + (kd, kh, kw), + (out_width, out_height, out_depth), + (stride_width, stride_height, stride_depth), + (dilation_width, dilation_height, dilation_depth), + ); + } else { + conv3d_mad_inner( + output.view_mut(), + x.view(), + k, + (kd, kh, kw), + (out_width, out_height, out_depth), + (stride_width, stride_height, stride_depth), + (dilation_width, dilation_height, dilation_depth), + ); + } + } + } + } + } + + if let Some(bias) = &bias { + let bias = bias[oc]; + + // Get a mutable iterator to the row we're looping over. + let orows = output.rows_mut(); + for mut or in orows { + // We explicitly define the bounds to 0..out_width so that rustc can make + // the assumption that all accesses are in-bounds. 
+ let or = &mut or.as_slice_mut().unwrap()[0..out_width]; + + #[allow(clippy::needless_range_loop)] + for ow in 0..out_width { + or[ow] += bias; + } + } + } + }, + ); + }); + + output + .to_shape([batch_size, out_channels, out_depth, out_height, out_width]) + .unwrap() + .into_dyn() + .into_shared() +} + +pub(crate) fn conv_transpose3d( + x: SharedArray, + weight: SharedArray, + bias: Option>, + options: ConvTransposeOptions<3>, +) -> SharedArray { + let [dilation_depth, dilation_height, dilation_width] = options.dilation; + let [padding_depth, padding_height, padding_width] = options.padding; + let [stride_depth, stride_height, stride_width] = options.stride; + let [out_padding_depth, out_padding_height, out_padding_width] = options.padding_out; + let [batch_size, _in_channels, in_depth, in_height, in_width] = x.shape().try_into().unwrap(); + let [ + in_channels, + out_channels, + kernel_depth, + kernel_height, + kernel_width, + ] = weight.shape().try_into().unwrap(); + + let out_depth = calculate_conv_transpose_output_size( + kernel_depth, + stride_depth, + padding_depth, + out_padding_depth, + dilation_depth, + in_depth, + ); + let out_height = calculate_conv_transpose_output_size( + kernel_height, + stride_height, + padding_height, + out_padding_height, + dilation_height, + in_height, + ); + let out_width = calculate_conv_transpose_output_size( + kernel_width, + stride_width, + padding_width, + out_padding_width, + dilation_width, + in_width, + ); + + let x = x; + let mut output = Array5::zeros(Dim([ + batch_size, + out_channels * options.groups, + out_depth, + out_height, + out_width, + ])); + + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, batch_size * out_channels * options.groups).for_each(|k| unsafe { + let b = k / (out_channels * options.groups); + let oc = k % out_channels; + let g = (k / out_channels) % options.groups; + + let output = unsafe_shared_out.get(); + + let oc_out = oc + (out_channels * g); + let ic_start = g * (in_channels / options.groups); + let ic_end = ic_start + in_channels / options.groups; + + for ic in ic_start..ic_end { + for id in 0..in_depth { + for ih in 0..in_height { + for iw in 0..in_width { + for kd in 0..kernel_depth { + for kh in 0..kernel_height { + for kw in 0..kernel_width { + let od = id * stride_depth + kd * dilation_depth; + let oh = ih * stride_height + kh * dilation_height; + let ow = iw * stride_width + kw * dilation_width; + + if od >= out_depth + padding_depth + || oh >= out_height + padding_height + || ow >= out_width + padding_width + || od < padding_depth + || oh < padding_height + || ow < padding_width + { + continue; + } + + let od = od - padding_depth; + let oh = oh - padding_height; + let ow = ow - padding_width; + + output[[b, oc_out, od, oh, ow]] += + x[[b, ic, id, ih, iw]] * weight[[ic, oc, kd, kh, kw]]; + } + } + } + } + } + } + } + + if let Some(bias) = &bias { + for od in 0..out_depth { + for oh in 0..out_height { + for ow in 0..out_width { + output[[b, oc_out, od, oh, ow]] += bias[oc_out]; + } + } + } + } + }); + }); + + output.into_dyn().into_shared() +} diff --git a/crates/burn/src/ops/deform_conv.rs b/crates/burn/src/ops/deform_conv.rs new file mode 100644 index 00000000..390010b9 --- /dev/null +++ b/crates/burn/src/ops/deform_conv.rs @@ -0,0 +1,662 @@ +use burn_backend::ops::{DeformConvOptions, conv::calculate_conv_output_size}; +use core::ops::AddAssign; +use ndarray::{ + Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim, Ix4, + 
Zip, s, +}; + +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] +use num_traits::Float; + +use crate::{FloatNdArrayElement, NdArrayTensor, ShapeOps, SharedArray, iter_par, run_par}; + +use super::matmul::matmul; + +#[inline(always)] +#[allow(clippy::too_many_arguments)] +fn deform_im2col_kernel( + out_y: usize, + out_x: usize, + input: ArrayView2, + offset: ArrayView3, + mask: Option>, + mut columns: ArrayViewMut2, + args: DeformConvOptions<2>, + (kernel_h, kernel_w): (usize, usize), +) { + // position shape: [in_channels, batch_size, out_h, out_w] + // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] + + let (height, width) = input.dim(); + + for kernel_y in 0..kernel_h { + for kernel_x in 0..kernel_w { + let mask_value = mask + .map(|it| it[[kernel_y, kernel_x]]) + .unwrap_or_else(|| F::from_elem(1.0)); + + let offset = offset.slice(s![kernel_y, kernel_x, ..]); + let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) + - F::from_elem(args.padding[0]) + + offset[0]; + let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) + - F::from_elem(args.padding[1]) + + offset[1]; + + let interpolated = bilinear_interpolate(input, height, width, y, x); + + columns[[kernel_y, kernel_x]] = mask_value * interpolated; + } + } +} + +fn bilinear_interpolate( + input: ArrayView2, + height: usize, + width: usize, + y: F, + x: F, +) -> F { + // To simplify code + let y = y.to_f32(); + let x = x.to_f32(); + + let mut result = F::from_elem(0.0); + if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x { + let y_low = f32::floor(y); + let x_low = f32::floor(x); + let y_high = (y_low + 1.) as usize; + let x_high = (x_low + 1.) as usize; + + let zero = F::from_elem(0.0); + let v1: F = if y_low >= 0. && x_low >= 0. { + input[[y_low as usize, x_low as usize]] + } else { + zero + }; + let v2: F = if y_low >= 0. && x_high < width { + input[[y_low as usize, x_high]] + } else { + zero + }; + let v3: F = if y_high < height && x_low >= 0. 
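+        // Worked example (reviewer note): bilinear interpolation blends the four
+        // neighbouring pixels with weights built from the fractional offsets
+        // l_y = y - floor(y), l_x = x - floor(x), as computed below:
+        //
+        //     result = (1-l_y)(1-l_x)*v1 + (1-l_y)*l_x*v2 + l_y*(1-l_x)*v3 + l_y*l_x*v4
+        //
+        // e.g. at (y, x) = (1.25, 2.5): l_y = 0.25, l_x = 0.5, so v1..v4 are read
+        // from rows 1-2, cols 2-3 and weighted 0.375, 0.375, 0.125, 0.125.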
{ + input[[y_high, x_low as usize]] + } else { + zero + }; + let v4: F = if y_high < height && x_high < width { + input[[y_high, x_high]] + } else { + zero + }; + + let l_y = y - y_low; + let l_x = x - x_low; + let h_y = 1.0 - l_y; + let h_x = 1.0 - l_x; + + let w1 = F::from_elem(h_y * h_x); + let w2 = F::from_elem(h_y * l_x); + let w3 = F::from_elem(l_y * h_x); + let w4 = F::from_elem(l_y * l_x); + + result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; + } + result +} + +pub(crate) fn deform_conv2d( + input: SharedArray, + offset: SharedArray, + weight: SharedArray, + mask: Option>, + bias: Option>, + args: DeformConvOptions<2>, +) -> SharedArray +where + NdArrayTensor: From>, +{ + let [batch_size, _, in_height, in_width] = input.shape().dims(); + let [out_channels, _, kernel_h, kernel_w] = weight.shape().dims(); + let groups = args.weight_groups; + + let weight = weight.as_standard_layout(); + + let out_h = calculate_conv_output_size( + kernel_h, + args.stride[0], + args.padding[0], + args.dilation[0], + in_height, + ); + let out_w = calculate_conv_output_size( + kernel_w, + args.stride[1], + args.padding[1], + args.dilation[1], + in_width, + ); + let out_dims = (out_h, out_w); + + let input = input.into_dimensionality::().unwrap(); + let offset = offset.into_dimensionality::().unwrap(); + let mask = mask.as_ref().map(|it| { + it.to_shape(( + batch_size, + args.offset_groups, + kernel_h, + kernel_w, + out_h, + out_w, + )) + .unwrap() + }); + + let columns = deform_im2col( + input.view(), + offset.view(), + mask.as_ref().map(|it| it.view()), + args, + out_dims, + (kernel_h, kernel_w), + ); + + let (col_size_0, col_size_1) = columns.dim(); + let col_size_0 = col_size_0 / groups; + let out_c_per_group = out_channels / groups; + + let weight = weight + .to_shape((groups, out_c_per_group, col_size_0)) + .unwrap(); + let columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap(); + let out = matmul( + weight.to_owned().into_dyn().into_shared(), + columns.to_owned().into_dyn().into_shared(), + ); + + let mut out = out + .into_shape_with_order((out_channels, batch_size, out_h, out_w)) + .unwrap(); + out.swap_axes(0, 1); + + if let Some(bias) = bias { + let bias = bias.to_shape((1, out_channels, 1, 1)).unwrap(); + out.add_assign(&bias); + } + + out.into_dyn().into_shared() +} + +pub(crate) fn deform_im2col( + input: ArrayView4, + offset: ArrayView4, + mask: Option>, + args: DeformConvOptions<2>, + out_dims: (usize, usize), + kernel_dims: (usize, usize), +) -> Array2 { + let (batch_size, in_channels, _, _) = input.dim(); + let (kernel_h, kernel_w) = kernel_dims; + let (out_h, out_w) = out_dims; + let channels_per_offset_group = in_channels / args.offset_groups; + + let mut columns = Array4::zeros(Dim([ + in_channels, + kernel_h, + kernel_w, + batch_size * out_h * out_w, + ])); + + let groups = args.offset_groups; + + run_par!(|| { + iter_par!(columns.axis_iter_mut(Axis(3))) + .enumerate() + .for_each(|(index, mut columns)| { + let out_x = index % out_w; + let out_y = (index / out_w) % out_h; + let batch = (index / (out_w * out_h)) % batch_size; + let offset = offset.slice(s![batch, .., out_y, out_x]); + let offset = offset.to_shape((groups, kernel_h, kernel_w, 2)).unwrap(); + let mask = mask + .as_ref() + .map(|it| it.slice(s![batch, .., .., .., out_y, out_x])); + columns + .axis_iter_mut(Axis(0)) + .enumerate() + .for_each(|(in_channel, mut columns)| { + let group_index = in_channel / channels_per_offset_group; + deform_im2col_kernel( + out_y, + out_x, + input.slice(s![batch, in_channel, 
.., ..]), + offset.slice(s![group_index, .., .., ..]), + mask.as_ref().map(|it| it.slice(s![group_index, .., ..])), + columns.view_mut(), + args.clone(), + kernel_dims, + ); + }); + }); + }); + + columns + // Columns is created here, so we know it's contiguous + .into_shape_with_order(( + in_channels * kernel_h * kernel_w, + batch_size * out_h * out_w, + )) + .unwrap() +} + +pub mod backward { + #[cfg(target_has_atomic = "32")] + use core::sync::atomic::Ordering; + + use atomic_float::AtomicF32; + use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4}; + + use super::*; + + pub(crate) type DeformConv2dBackward = ( + SharedArray, + SharedArray, + SharedArray, + Option>, + Option>, + ); + + /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. + pub(crate) fn deform_conv2d_backward( + input: SharedArray, + offset: SharedArray, + weight: SharedArray, + mask: Option>, + bias: Option>, + out_grad: SharedArray, + args: DeformConvOptions<2>, + ) -> DeformConv2dBackward { + let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims(); + let [_, _, kernel_h, kernel_w] = weight.shape().dims(); + let groups = args.weight_groups; + let out_c_per_group = out_channels / groups; + let col_shape_1 = batch_size * out_h * out_w; + let mut out_grad = out_grad.into_dimensionality::().unwrap(); + + let gradient_bias = bias.map(|_| { + let out_grad = out_grad + .clone() + .sum_axis(Axis(0)) + .sum_axis(Axis(1)) + .sum_axis(Axis(1)); + + out_grad.into_dyn().into_shared() + }); + + out_grad.swap_axes(0, 1); + let out_grad = out_grad + .to_shape((groups, out_c_per_group, col_shape_1)) + .unwrap(); + + let input = input.into_dimensionality::().unwrap(); + let offset = offset.into_dimensionality::().unwrap(); + let mask = mask.map(|it| { + it.into_shape_with_order(( + batch_size, + args.offset_groups, + kernel_h, + kernel_w, + out_h, + out_w, + )) + .unwrap() + }); + + let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs( + input.view(), + weight, + offset.view(), + mask.as_ref().map(|it| it.view()), + out_grad.view(), + &args, + (kernel_h, kernel_w), + ); + + let weight_grad = compute_weight_grad( + input.view(), + offset.view(), + mask.as_ref().map(|it| it.view()), + out_grad.view(), + args, + (kernel_h, kernel_w), + (out_h, out_w), + ); + + ( + input_gradient, + offset_gradient, + weight_grad, + mask_gradient, + gradient_bias, + ) + } + + fn compute_weight_grad( + input: ArrayView4, + offset: ArrayView4, + mask: Option>, + out_grad: ArrayView3, + options: DeformConvOptions<2>, + kernel_dims: (usize, usize), + out_dims: (usize, usize), + ) -> SharedArray { + let in_channels = input.dim().1; + let (groups, out_c_per_group, _) = out_grad.dim(); + let (kernel_h, kernel_w) = kernel_dims; + + let in_c_per_group = in_channels / groups; + + let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims); + let (col_size_0, col_size_1) = columns.dim(); + let col_size_0 = col_size_0 / groups; + + let mut columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap(); + columns.swap_axes(1, 2); + + let grad_weight = matmul( + out_grad.to_owned().into_dyn().into_shared(), + columns.to_owned().into_dyn().into_shared(), + ); + + let grad_weight = grad_weight + .into_shape_with_order((out_c_per_group * groups, in_c_per_group, kernel_h, kernel_w)) + .unwrap(); + grad_weight.into_dyn().into_shared() + } + + type InputGradients = (SharedArray, SharedArray, Option>); + + fn backward_gradient_inputs( + 
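+        // image: forward-pass input (NCHW); weight: convolution kernels;
+        // out_grad: output gradient, already grouped as (groups, out_c_per_group, N)
+        // where N = batch_size * out_h * out_w.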
image: ArrayView4, + weight: SharedArray, + offset: ArrayView4, + mask: Option>, + out_grad: ArrayView3, + args: &DeformConvOptions<2>, + kernel_dims: (usize, usize), + ) -> InputGradients { + let input_shape = image.dim(); + let in_channels = input_shape.1; + let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape().dims(); + let (batch_size, _, out_h, out_w) = offset.dim(); + + let groups = args.weight_groups; + let out_c_per_group = out_channels / groups; + + let col_shape_0 = in_c_per_group * kernel_h * kernel_w; + + let mut weight = weight + .to_shape((groups, out_c_per_group, col_shape_0)) + .unwrap(); + weight.swap_axes(1, 2); + let columns = matmul( + weight.to_owned().into_dyn().into_shared(), + out_grad.to_owned().into_dyn().into_shared(), + ); + + let columns = columns + .to_shape((in_channels, kernel_h, kernel_w, batch_size, out_h, out_w)) + .unwrap(); + + let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient( + columns.view(), + image.view(), + offset, + mask, + args, + kernel_dims, + ); + + let input_gradient = + compute_input_grad(columns.view(), offset, mask, args, kernel_dims, input_shape); + + (input_gradient, offset_gradient, mask_gradient) + } + + fn compute_offset_and_mask_gradient( + columns: ArrayView6, + image: ArrayView4, + offset: ArrayView4, + mask: Option>, + args: &DeformConvOptions<2>, + kernel_dims: (usize, usize), + ) -> (SharedArray, Option>) { + let (kernel_h, kernel_w) = kernel_dims; + let (_, in_channels, height, width) = image.dim(); + let (batch_size, offset_channels, out_h, out_w) = offset.dim(); + let offs_groups = args.offset_groups; + let channels_per_offset_group = in_channels / args.offset_groups; + + let mut grad_offset = Array5::zeros(( + offs_groups, + kernel_h, + kernel_w, + 2, + batch_size * out_h * out_w, + )); + let mut grad_mask = + Array4::zeros((offs_groups, kernel_h, kernel_w, batch_size * out_h * out_w)); + + grad_mask + .axis_iter_mut(Axis(3)) + .zip(grad_offset.axis_iter_mut(Axis(4))) + .enumerate() + .for_each(|(index, (mut grad_mask, mut grad_offset))| { + let out_x = index % out_w; + let out_y = (index / out_w) % out_h; + let batch = index / (out_w * out_h); + let offset = offset.slice(s![batch, .., out_y, out_x]); + let offset = offset + .to_shape((offs_groups, kernel_h, kernel_w, 2)) + .unwrap(); + let mask: Option> = mask + .as_ref() + .map(|mask| mask.slice(s![batch, .., .., .., out_y, out_x])); + let columns = columns.slice(s![.., .., .., batch, out_y, out_x]); + let image = image.slice(s![batch, .., .., ..]); + + for ((group, kernel_y, kernel_x), grad_mask) in grad_mask.indexed_iter_mut() { + let grad_mask: &mut F = grad_mask; + let mut grad_offset = grad_offset.slice_mut(s![group, kernel_y, kernel_x, ..]); + let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); + let mask = mask.map(|it| it[[group, kernel_y, kernel_x]]); + let columns = columns.slice(s![.., kernel_y, kernel_x]); + let group_offset = group * channels_per_offset_group; + let image = image.slice(s![group_offset.., .., ..]); + let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) + - F::from_elem(args.padding[0]) + + offset[0]; + let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) + - F::from_elem(args.padding[1]) + + offset[1]; + for (i, grad_offset) in grad_offset.iter_mut().enumerate() { + let is_y_direction = i % 2 == 0; + let use_mask = mask.is_some(); + + for channel in 0..channels_per_offset_group { + let mask = mask.unwrap_or_else(|| F::one()); + let image = 
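+                        // take this channel's plane within the offset group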
image.index_axis(Axis(0), channel); + let weight = + get_coordinate_weight(image, height, width, y, x, is_y_direction); + *grad_offset += mask * weight * columns[channel]; + if use_mask && is_y_direction { + *grad_mask += columns[channel] + * bilinear_interpolate(image, height, width, y, x); + } + } + } + } + }); + + let mask_gradient = mask.map(|_| { + let mut grad_mask = grad_mask + .into_shape_with_order((offset_channels / 2, batch_size, out_h, out_w)) + .unwrap(); + grad_mask.swap_axes(0, 1); + grad_mask.into_dyn().into_shared() + }); + let mut grad_offset = grad_offset + .into_shape_with_order((offset_channels, batch_size, out_h, out_w)) + .unwrap(); + grad_offset.swap_axes(0, 1); + let offset_gradient = grad_offset.into_dyn().into_shared(); + (offset_gradient, mask_gradient) + } + + fn get_coordinate_weight( + input: ArrayView2, + height: usize, + width: usize, + y: F, + x: F, + is_y_direction: bool, + ) -> F { + let y = y.to_f32(); + let x = x.to_f32(); + + let y_low = f32::floor(y); + let x_low = f32::floor(x); + let y_high = y_low + 1.; + let x_high = x_low + 1.; + + let valid_y_low = y_low >= 0. && y_low < height as f32; + let valid_y_high = y_high >= 0. && y_high < height as f32; + let valid_x_low = x_low >= 0. && x_low < width as f32; + let valid_x_high = x_high >= 0. && x_high < width as f32; + + let bottom_left = if valid_y_low && valid_x_low { + input[[y_low as usize, x_low as usize]] + } else { + F::zero() + }; + let bottom_right = if valid_y_low && valid_x_high { + input[[y_low as usize, x_high as usize]] + } else { + F::zero() + }; + let top_left = if valid_y_high && valid_x_low { + input[[y_high as usize, x_low as usize]] + } else { + F::zero() + }; + let top_right = if valid_y_high && valid_x_high { + input[[y_high as usize, x_high as usize]] + } else { + F::zero() + }; + + if is_y_direction { + let delta_x = F::from_elem(x - x_low); + delta_x * (top_right - bottom_right) + (F::one() - delta_x) * (top_left - bottom_left) + } else { + let delta_y = F::from_elem(y - y_low); + delta_y * (top_right - top_left) + (F::one() - delta_y) * (bottom_right - bottom_left) + } + } + + fn compute_input_grad( + columns: ArrayView6, + offset: ArrayView4, + mask: Option>, + args: &DeformConvOptions<2>, + kernel_dims: (usize, usize), + input_shape: (usize, usize, usize, usize), + ) -> SharedArray { + let (batch_size, in_channels, height, width) = input_shape; + let (kernel_h, kernel_w) = kernel_dims; + let offs_groups = args.offset_groups; + let channels_per_offset_group = in_channels / offs_groups; + + let grad_in = + Array4::from_shape_simple_fn((batch_size, in_channels, height, width), || { + AtomicF32::new(0.0) + }); + + let compute_for_each = |(in_channel, kernel_y, kernel_x, batch, out_y, out_x), col: &F| { + let group = in_channel / channels_per_offset_group; + let offset = offset.slice(s![batch, .., out_y, out_x]); + let offset = offset + .to_shape((offs_groups, kernel_h, kernel_w, 2)) + .unwrap(); + let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); + let offset = [offset[0], offset[1]]; + let mask = mask + .as_ref() + .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32()); + let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) + - F::from_elem(args.padding[0]) + + offset[0]; + let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) + - F::from_elem(args.padding[1]) + + offset[1]; + let grad_in = grad_in.slice(s![batch, in_channel, .., ..]); + deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), 
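+            // scatter-add into the shared gradient plane through AtomicF32 cells
+            // (see deform_col2img_kernel below), so parallel writers never race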
grad_in);
+        };
+
+        // `for_each` expects a 2-tuple argument with `.into_par_iter()`, but 2 separate arguments otherwise
+        #[cfg(feature = "multi-threads")]
+        run_par!(|| {
+            iter_par!(Zip::indexed(columns))
+                .for_each(|(args0, args1)| compute_for_each(args0, args1))
+        });
+
+        #[cfg(not(feature = "multi-threads"))]
+        run_par!(|| { iter_par!(Zip::indexed(columns)).for_each(&compute_for_each) });
+
+        let grad_in: Array1<F> = grad_in
+            .into_iter()
+            .map(|it| F::from_elem(it.into_inner()))
+            .collect();
+        let grad_in = grad_in
+            .into_shape_with_order((batch_size, in_channels, height, width))
+            .unwrap();
+        grad_in.into_dyn().into_shared()
+    }
+
+    fn deform_col2img_kernel(
+        y: f32,
+        x: f32,
+        mask: Option<f32>,
+        col: f32,
+        grad_input: ArrayView2<AtomicF32>,
+    ) {
+        let (height, width) = grad_input.dim();
+        let mask_value = mask.unwrap_or(1.0);
+
+        for dy in -1..=1 {
+            for dx in -1..=1 {
+                let yp = f32::floor(y) + dy as f32;
+                let xp = f32::floor(x) + dx as f32;
+
+                if yp >= 0.0
+                    && yp < height as f32
+                    && xp >= 0.0
+                    && xp < width as f32
+                    && f32::abs(y - yp) < 1.0
+                    && f32::abs(x - xp) < 1.0
+                {
+                    let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp));
+
+                    #[cfg_attr(not(target_has_atomic = "32"), allow(unused))]
+                    let value = mask_value * weight * col;
+
+                    #[cfg(target_has_atomic = "32")]
+                    grad_input[[yp as usize, xp as usize]].fetch_add(value, Ordering::AcqRel);
+                    #[cfg(not(target_has_atomic = "32"))]
+                    panic!("Can't use deformable convolution backwards pass without atomics");
+                }
+            }
+        }
+    }
+}
diff --git a/crates/burn/src/ops/grid_sample.rs b/crates/burn/src/ops/grid_sample.rs
new file mode 100644
index 00000000..256c2fd8
--- /dev/null
+++ b/crates/burn/src/ops/grid_sample.rs
@@ -0,0 +1,214 @@
+use burn_backend::ElementConversion;
+use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
+#[cfg(not(feature = "std"))]
+#[allow(unused_imports)]
+use num_traits::Float;
+
+use ndarray::Array4;
+
+use crate::SharedArray;
+use crate::{FloatNdArrayElement, UnsafeSharedRef, iter_range_par, run_par};
+
+/// Sample a tensor using grid-based sampling.
+///
+/// # Arguments
+///
+/// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)
+/// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
+/// A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right +/// * `options` - Grid sampling options (mode, padding_mode, align_corners) +/// +/// # Returns +/// +/// A tensor with shape (N, C, H_out, W_out) +pub(crate) fn grid_sample_2d( + tensor: SharedArray, + grid: SharedArray, + options: GridSampleOptions, +) -> SharedArray { + match options.mode { + InterpolateMode::Bilinear => (), + _ => todo!( + "grid_sample_2d with {:?} mode is not implemented", + options.mode + ), + } + + let tensor = tensor.into_dimensionality::().unwrap(); + let grid = grid.into_dimensionality::().unwrap(); + + let (batch_size, channels, height_in, width_in) = tensor.dim(); + let (b, height_out, width_out, d) = grid.dim(); + assert!(batch_size == b); + assert!(2 == d); + + let mut output = Array4::zeros((batch_size, channels, height_out, width_out)); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + let sample_count = batch_size * channels * height_out * width_out; + let strides = ( + channels * height_out * width_out, + height_out * width_out, + width_out, + ); + + let align = options.align_corners; + let pad_mode = options.padding_mode; + + run_par!(|| { + iter_range_par!(0, sample_count).for_each(|id| { + let (b, c, y, x) = ( + id / strides.0, + id % strides.0 / strides.1, + id % strides.1 / strides.2, + id % strides.2, + ); + + let sample_x = grid[(b, y, x, 0)].elem::(); + let sample_y = grid[(b, y, x, 1)].elem::(); + + // Convert normalized grid coordinates [-1, 1] to pixel coordinates + let (px, py) = if align { + // align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2 + // Maps -1 to 0 and 1 to width - 1 + let px = (sample_x + 1.0) * ((width_in - 1) as f64) / 2.0; + let py = (sample_y + 1.0) * ((height_in - 1) as f64) / 2.0; + (px, py) + } else { + // align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5 + // Maps -1 to -0.5 and 1 to width - 0.5 + let px = (sample_x + 1.0) * (width_in as f64) / 2.0 - 0.5; + let py = (sample_y + 1.0) * (height_in as f64) / 2.0 - 0.5; + (px, py) + }; + + // Bilinear interpolation with the specified padding mode + let val = + bilinear_interpolate(&tensor, b, c, px, py, width_in, height_in, pad_mode, align); + + unsafe { + let output = unsafe_shared_out.get(); + output[(b, c, y, x)] = val.elem(); + } + }); + }); + + output.into_dyn().into_shared() +} + +/// Bilinear interpolation at a point with configurable padding mode. 
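+///
+/// Out-of-range sample points are resolved by `padding_mode`: `Zeros` reads 0,
+/// `Border` clamps to the edge pixel, and `Reflection` folds the coordinate
+/// back into range before the four corner reads.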
+#[allow(clippy::too_many_arguments)] +fn bilinear_interpolate( + source: &ndarray::ArrayBase>, + b: usize, + c: usize, + x: f64, + y: f64, + width: usize, + height: usize, + padding_mode: GridSamplePaddingMode, + align_corners: bool, +) -> f64 +where + E: FloatNdArrayElement, + S: ndarray::Data, +{ + // Handle inf/nan coordinates + if !x.is_finite() || !y.is_finite() { + return match padding_mode { + GridSamplePaddingMode::Zeros => 0.0, + GridSamplePaddingMode::Border => { + // Clamp to center of image for inf/nan + let cx = ((width - 1) as f64 / 2.0).clamp(0.0, (width - 1) as f64); + let cy = ((height - 1) as f64 / 2.0).clamp(0.0, (height - 1) as f64); + source[(b, c, cy as usize, cx as usize)].elem::() + } + GridSamplePaddingMode::Reflection => 0.0, // Simplified: treat as zeros for inf/nan + }; + } + + // Apply padding mode to get actual sampling coordinates + let (x, y) = match padding_mode { + GridSamplePaddingMode::Border => { + // Clamp coordinates to valid range [0, size-1] + let x = x.clamp(0.0, (width - 1) as f64); + let y = y.clamp(0.0, (height - 1) as f64); + (x, y) + } + GridSamplePaddingMode::Reflection => { + // Reflect coordinates at boundaries + let x = reflect_coordinate(x, width, align_corners); + let y = reflect_coordinate(y, height, align_corners); + (x, y) + } + GridSamplePaddingMode::Zeros => (x, y), // Keep as-is, handle out-of-bounds in read + }; + + // Get the four corner indices + let x0 = x.floor() as i64; + let y0 = y.floor() as i64; + let x1 = x0.saturating_add(1); + let y1 = y0.saturating_add(1); + + // Compute interpolation weights (fractional part) + let x_frac = x - x.floor(); + let y_frac = y - y.floor(); + + // Helper to read a value based on padding mode + let read_value = |xi: i64, yi: i64| -> f64 { + match padding_mode { + GridSamplePaddingMode::Zeros => { + // Return 0 for out-of-bounds + if xi >= 0 && xi < width as i64 && yi >= 0 && yi < height as i64 { + source[(b, c, yi as usize, xi as usize)].elem::() + } else { + 0.0 + } + } + GridSamplePaddingMode::Border | GridSamplePaddingMode::Reflection => { + // Coordinates should already be in valid range after clamping/reflection + let xi = xi.clamp(0, (width - 1) as i64) as usize; + let yi = yi.clamp(0, (height - 1) as i64) as usize; + source[(b, c, yi, xi)].elem::() + } + } + }; + + // Read the four corners + let v00 = read_value(x0, y0); + let v01 = read_value(x0, y1); + let v10 = read_value(x1, y0); + let v11 = read_value(x1, y1); + + // Bilinear interpolation weights + let w00 = (1.0 - x_frac) * (1.0 - y_frac); + let w01 = (1.0 - x_frac) * y_frac; + let w10 = x_frac * (1.0 - y_frac); + let w11 = x_frac * y_frac; + + v00 * w00 + v01 * w01 + v10 * w10 + v11 * w11 +} + +/// Reflect a coordinate at the boundaries using a triangle wave pattern. 
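+///
+/// For example, with `size = 4` and `align_corners = true` the valid span is
+/// [0, 3], so a coordinate of 4.5 folds back to 1.5 and 7.0 folds to 1.0.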
+/// +/// For align_corners=true: reflects within [0, size-1] +/// For align_corners=false: reflects within [-0.5, size-0.5] +fn reflect_coordinate(coord: f64, size: usize, align_corners: bool) -> f64 { + let size_f = size as f64; + let (min_val, max_val) = if align_corners { + (0.0, size_f - 1.0) + } else { + (-0.5, size_f - 0.5) + }; + + let span = max_val - min_val; + if span <= 0.0 { + return min_val; + } + + // Triangle wave formula: span - |((x mod 2*span) - span)| + let period = 2.0 * span; + let x = (coord - min_val).abs(); + let x_mod = x - (x / period).floor() * period; + span - (x_mod - span).abs() + min_val +} diff --git a/crates/burn/src/ops/int_tensor.rs b/crates/burn/src/ops/int_tensor.rs new file mode 100644 index 00000000..02710cdc --- /dev/null +++ b/crates/burn/src/ops/int_tensor.rs @@ -0,0 +1,509 @@ +// Language +use crate::rand::get_seeded_rng; +use alloc::vec::Vec; +use burn_backend::backend::ExecutionError; +use burn_backend::ops::IntTensorOps; +use burn_backend::tensor::{FloatTensor, IntTensor}; +use burn_backend::{Distribution, IntDType, Scalar, TensorMetadata}; + +use burn_backend::ElementConversion; +use burn_std::{BoolDType, FloatDType}; + +// Current crate +use crate::{ExpElement, NdArrayDevice, SEED, execute_with_int_out_dtype, slice}; +use crate::{NdArray, cast_to_dtype, execute_with_dtype, tensor::NdArrayTensor}; +use crate::{SharedArray, element::QuantElement}; +use crate::{cat_with_dtype, execute_with_float_out_dtype}; +use crate::{element::FloatNdArrayElement, ops::matmul::matmul}; +use crate::{element::IntNdArrayElement, execute_with_int_dtype}; + +// Workspace crates +use super::{NdArrayBitOps, NdArrayMathOps, NdArrayOps}; +use burn_backend::{DType, Shape, TensorData, backend::Backend}; + +impl IntTensorOps + for NdArray +where + NdArrayTensor: From>, + NdArrayTensor: From>, +{ + fn int_from_data(data: TensorData, _device: &NdArrayDevice) -> NdArrayTensor { + if data.dtype.is_int() || data.dtype.is_uint() { + NdArrayTensor::from_data(data) + } else { + unimplemented!("Unsupported dtype for `int_from_data`: {:?}", data.dtype) + } + } + + async fn int_into_data(tensor: NdArrayTensor) -> Result { + Ok(tensor.into_data()) + } + + fn int_to_device(tensor: NdArrayTensor, _device: &NdArrayDevice) -> NdArrayTensor { + tensor + } + + fn int_reshape(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayOps::reshape(array, shape)) + } + + fn int_slice(tensor: NdArrayTensor, slices: &[burn_backend::Slice]) -> NdArrayTensor { + slice!(tensor, slices) + } + + fn int_device(_tensor: &NdArrayTensor) -> as Backend>::Device { + NdArrayDevice::Cpu + } + + fn int_empty( + shape: Shape, + device: & as Backend>::Device, + dtype: IntDType, + ) -> NdArrayTensor { + Self::int_zeros(shape, device, dtype) + } + + fn int_matmul(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + execute_with_int_dtype!((lhs, rhs), matmul) + } + + fn int_mask_where( + tensor: NdArrayTensor, + mask: NdArrayTensor, + source: NdArrayTensor, + ) -> NdArrayTensor { + execute_with_int_dtype!((tensor, source), |tensor, source| { + NdArrayOps::mask_where(tensor, mask.bool(), source) + }) + } + + fn int_mask_fill(tensor: NdArrayTensor, mask: NdArrayTensor, value: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayOps::mask_fill( + array, + mask.bool(), + value.elem() + )) + } + + fn int_slice_assign( + tensor: NdArrayTensor, + slices: &[burn_backend::Slice], + value: NdArrayTensor, + ) -> NdArrayTensor { + 
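+        // Dispatch on the runtime integer dtype shared by `tensor` and `value`,
+        // then delegate to the typed slice-assign kernel.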
execute_with_int_dtype!((tensor, value), |tensor, value| NdArrayOps::slice_assign( + tensor, slices, value + )) + } + + fn int_cat(tensors: Vec, dim: usize) -> NdArrayTensor { + cat_with_dtype!(tensors, dim, [I64, I32, I16, I8, U64, U32, U16, U8]) + } + + fn int_equal(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::equal) + } + + fn int_equal_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayMathOps::equal_elem(array, rhs.elem())) + } + + fn int_greater(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater) + } + + fn int_greater_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_elem(array, rhs.elem())) + } + + fn int_greater_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + _out_dtype: BoolDType, + ) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::greater_equal) + } + + fn int_greater_equal_elem( + lhs: NdArrayTensor, + rhs: Scalar, + _out_dtype: BoolDType, + ) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayMathOps::greater_equal_elem( + array, + rhs.elem() + )) + } + + fn int_lower(lhs: NdArrayTensor, rhs: NdArrayTensor, _out_dtype: BoolDType) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower) + } + + fn int_lower_elem(lhs: NdArrayTensor, rhs: Scalar, _out_dtype: BoolDType) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_elem(array, rhs.elem())) + } + + fn int_lower_equal( + lhs: NdArrayTensor, + rhs: NdArrayTensor, + _out_dtype: BoolDType, + ) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::lower_equal) + } + + fn int_lower_equal_elem( + lhs: NdArrayTensor, + rhs: Scalar, + _out_dtype: BoolDType, + ) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayMathOps::lower_equal_elem( + array, + rhs.elem() + )) + } + + fn int_add(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::add) + } + + fn int_add_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayMathOps::add_scalar(array, rhs.elem())) + } + + fn int_sub(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::sub) + } + + fn int_sub_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayMathOps::sub_scalar(array, rhs.elem())) + } + + fn int_mul(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::mul) + } + + fn int_mul_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayMathOps::mul_scalar(array, rhs.elem())) + } + + fn int_div(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::div) + } + + fn int_div_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayMathOps::div_scalar(array, rhs.elem())) + } + + fn int_remainder(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayMathOps::remainder) + } + + fn int_remainder_scalar(lhs: NdArrayTensor, rhs: Scalar) -> 
NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayMathOps::remainder_scalar( + array, + rhs.elem() + )) + } + + fn int_sum(tensor: NdArrayTensor) -> NdArrayTensor { + // Use view() for zero-copy on borrowed storage + execute_with_int_dtype!(tensor, E, |array: SharedArray| NdArrayMathOps::sum_view( + array.view() + )) + } + + fn int_sum_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayMathOps::sum_dim(array, dim)) + } + + fn int_prod(tensor: NdArrayTensor) -> NdArrayTensor { + // Use view() for zero-copy on borrowed storage + execute_with_int_dtype!( + tensor, + E, + |array: SharedArray| NdArrayMathOps::prod_view(array.view()) + ) + } + + fn int_prod_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayMathOps::prod_dim(array, dim)) + } + + fn int_mean(tensor: NdArrayTensor) -> NdArrayTensor { + // Use view() for zero-copy on borrowed storage + execute_with_int_dtype!( + tensor, + E, + |array: SharedArray| NdArrayMathOps::mean_view(array.view()) + ) + } + + fn int_mean_dim(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayMathOps::mean_dim(array, dim)) + } + + fn int_max(tensor: NdArrayTensor) -> NdArrayTensor { + // Use view() for zero-copy on borrowed storage + execute_with_int_dtype!(tensor, E, |array: SharedArray| NdArrayMathOps::max_view( + array.view() + )) + } + + fn int_min(tensor: NdArrayTensor) -> NdArrayTensor { + // Use view() for zero-copy on borrowed storage + execute_with_int_dtype!(tensor, E, |array: SharedArray| NdArrayMathOps::min_view( + array.view() + )) + } + + fn int_cumsum(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumsum(array, dim)) + } + + fn int_cumprod(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cumprod(array, dim)) + } + + fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummin(array, dim)) + } + + fn int_cummax(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayMathOps::cummax(array, dim)) + } + + fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor { + execute_with_int_dtype!(indices, |idx_array| NdArrayOps::gather( + dim, array, idx_array + )) + }) + } + + fn int_scatter_add( + dim: usize, + tensor: NdArrayTensor, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor { + execute_with_int_dtype!(indices, |idx_array| NdArrayOps::::scatter( + dim, tensor, idx_array, value + )) + }) + } + + fn int_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!(tensor, E, |array| -> NdArrayTensor { + execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::select( + array, dim, idx_array + )) + }) + } + + fn int_select_add( + tensor: NdArrayTensor, + dim: usize, + indices: NdArrayTensor, + value: NdArrayTensor, + ) -> NdArrayTensor { + execute_with_int_dtype!((tensor, value), I, |tensor, value| -> NdArrayTensor { + execute_with_int_dtype!(indices, |idx_array| NdArrayMathOps::::select_assign( + tensor, dim, idx_array, value + )) + }) + } + fn int_argmax(tensor: NdArrayTensor, dim: 
usize) -> NdArrayTensor { + // Use view() for zero-copy on borrowed storage + execute_with_int_dtype!(tensor, E, |array: SharedArray| { + NdArrayMathOps::argmax_view::(array.view(), dim) + }) + } + + fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor { + // Use view() for zero-copy on borrowed storage + execute_with_int_dtype!(tensor, E, |array: SharedArray| { + NdArrayMathOps::argmin_view::(array.view(), dim) + }) + } + + fn int_clamp_min(tensor: NdArrayTensor, min: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_min(array, min.elem())) + } + + fn int_clamp_max(tensor: NdArrayTensor, max: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp_max(array, max.elem())) + } + + fn int_clamp(tensor: NdArrayTensor, min: Scalar, max: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayMathOps::clamp( + array, + min.elem(), + max.elem() + )) + } + + fn int_abs(tensor: NdArrayTensor) -> NdArrayTensor { + match tensor.dtype() { + DType::I64 | DType::I32 | DType::I16 | DType::I8 => { + execute_with_dtype!(tensor, I, NdArrayMathOps::abs, [ + I64 => i64, I32 => i32, I16 => i16, I8 => i8 + ]) + } + // Already unsigned + DType::U64 | DType::U32 | DType::U16 | DType::U8 => tensor, + other => panic!("Unsupported dtype: {other:?}"), + } + } + + fn int_into_float(tensor: NdArrayTensor, out_dtype: FloatDType) -> FloatTensor { + execute_with_float_out_dtype!(out_dtype, F, { + execute_with_int_dtype!(tensor, IntElem, |array: SharedArray| { + array.mapv(|a: IntElem| a.elem::()).into_shared() + }) + }) + } + + fn int_swap_dims(tensor: NdArrayTensor, dim1: usize, dim2: usize) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayOps::swap_dims(array, dim1, dim2)) + } + + fn int_random( + shape: Shape, + distribution: Distribution, + device: &NdArrayDevice, + dtype: IntDType, + ) -> NdArrayTensor { + let mut seed = SEED.lock().unwrap(); + let mut rng = seed.take().unwrap_or_else(get_seeded_rng); + + let effective_distribution = if distribution == Distribution::Default { + Distribution::Uniform(0.0, 255.0) // Assuming UniformInt is the integer variant + } else { + distribution + }; + + let tensor = execute_with_int_out_dtype!( + dtype, + I, + Self::int_from_data( + TensorData::random::(shape, effective_distribution, &mut rng), + device, + ) + ); + *seed = Some(rng); + tensor + } + + fn int_powi(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| NdArrayMathOps::elementwise_op( + lhs, + rhs, + |a: &I, b: &I| { (a.elem::().pow(b.elem::())).elem() } + )) + } + + fn int_permute(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayOps::permute(array, axes)) + } + + fn int_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayOps::flip(array, axes)) + } + + fn int_sign(tensor: NdArrayTensor) -> NdArrayTensor { + match tensor.dtype() { + DType::I64 | DType::I32 | DType::I16 | DType::I8 => { + execute_with_dtype!(tensor, I, NdArrayMathOps::sign_op, [ + I64 => i64, I32 => i32, I16 => i16, I8 => i8 + ]) + } + DType::U64 | DType::U32 | DType::U16 | DType::U8 => { + Self::int_greater_elem(tensor, 0.into(), BoolDType::Native) + } + other => panic!("Unsupported dtype: {other:?}"), + } + } + + fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { + execute_with_int_dtype!(tensor, |array| NdArrayOps::expand(array, 
shape)) + } + + fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitand) + } + + fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitand_scalar(array, rhs.elem())) + } + + fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitor) + } + + fn bitwise_or_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitor_scalar(array, rhs.elem())) + } + + fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), NdArrayBitOps::bitxor) + } + + fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(lhs, |array| NdArrayBitOps::bitxor_scalar(array, rhs.elem())) + } + + fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!(tensor, NdArrayBitOps::bitnot) + } + + fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() << (b.elem::())).elem() + }) + }) + } + + fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(lhs, I, |array| { + NdArrayMathOps::elementwise_op_scalar(array, |a: I| { + (a.elem::() << rhs.elem::()).elem() + }) + }) + } + + fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + execute_with_int_dtype!((lhs, rhs), I, |lhs, rhs| { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() >> (b.elem::())).elem() + }) + }) + } + + fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: Scalar) -> NdArrayTensor { + execute_with_int_dtype!(lhs, I, |array| { + NdArrayMathOps::elementwise_op_scalar(array, |a: I| { + (a.elem::() >> rhs.elem::()).elem() + }) + }) + } + + fn int_cast(tensor: IntTensor, dtype: IntDType) -> IntTensor { + execute_with_int_dtype!(tensor, |array| cast_to_dtype(array, dtype.into())) + } + + fn int_unfold( + tensor: IntTensor, + dim: usize, + size: usize, + step: usize, + ) -> IntTensor { + execute_with_int_dtype!(tensor, |array| NdArrayOps::unfold(array, dim, size, step)) + } + + fn int_powi_scalar_impl(lhs: IntTensor, rhs: Scalar) -> IntTensor { + execute_with_int_dtype!(lhs, I, |array| { + NdArrayMathOps::elementwise_op_scalar(array, |a: I| a.powi_elem(rhs.elem())) + }) + } +} diff --git a/crates/burn/src/ops/interpolate.rs b/crates/burn/src/ops/interpolate.rs new file mode 100644 index 00000000..af9d50d1 --- /dev/null +++ b/crates/burn/src/ops/interpolate.rs @@ -0,0 +1,397 @@ +use burn_backend::ElementConversion; +use ndarray::{Array4, ArrayBase, DataOwned}; +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] +use num_traits::Float; + +use crate::{FloatNdArrayElement, ShapeOps, SharedArray, UnsafeSharedRef, iter_range_par, run_par}; + +pub(crate) fn nearest_interpolate( + x: SharedArray, + output_size: [usize; 2], +) -> SharedArray { + let x = x.into_dimensionality::().unwrap(); + + let (batch_size, channels, in_height, in_width) = x.dim(); + let [out_height, out_width] = output_size; + + let y_ratio = (in_height as f64) / (out_height as f64); + let x_ratio = (in_width as f64) / (out_width as f64); + + let out_element_num = batch_size * channels * out_height * out_width; + let strides = ( + channels * out_height 
* out_width, + out_height * out_width, + out_width, + ); + + let mut output = Array4::zeros((batch_size, channels, out_height, out_width)); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, out_element_num).for_each(|id| { + let (b, c, h, w) = ( + id / strides.0, + id % strides.0 / strides.1, + id % strides.1 / strides.2, + id % strides.2, + ); + + let y_in = (y_ratio * h as f64).floor() as usize; + let x_in = (x_ratio * w as f64).floor() as usize; + + unsafe { + let output = unsafe_shared_out.get(); + output[(b, c, h, w)] = x[(b, c, y_in, x_in)]; + } + }); + }); + + output.into_dyn().into_shared() +} + +pub(crate) fn nearest_interpolate_backward( + x: SharedArray, + grad: SharedArray, + output_size: [usize; 2], +) -> SharedArray { + let [batch_size, channels, input_height, input_width] = x.shape().dims(); + let [output_height, output_width] = output_size; + + let mut output_grad = + Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem()); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad); + + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output_grad = unsafe_shared_out.get(); + + for oh in 0..output_height { + for ow in 0..output_width { + let ih = start_index(oh, output_height, input_height); + let iw = start_index(ow, output_width, input_width); + + output_grad[[b, c, ih, iw]] += grad[[b, c, oh, ow]] + } + } + }) + }); + + output_grad.into_dyn().into_shared() +} + +fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { + ((output_size_index as f32 * input_size as f32) / output_size as f32).floor() as usize +} + +// clamp ceil(frac) to stay within bounds in case of floating-point imprecision +pub(crate) fn ceil_clamp(frac: f64, max: usize) -> f64 { + frac.ceil().min(max as f64) +} + +pub(crate) fn bilinear_interpolate( + x: SharedArray, + output_size: [usize; 2], + align_corners: bool, +) -> SharedArray { + let x = x.into_dimensionality::().unwrap(); + + let (batch_size, channels, in_height, in_width) = x.dim(); + let [out_height, out_width] = output_size; + + let out_element_num = batch_size * channels * out_height * out_width; + let strides = ( + channels * out_height * out_width, + out_height * out_width, + out_width, + ); + + let mut output = Array4::zeros((batch_size, channels, out_height, out_width)); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, out_element_num).for_each(|id| { + let (b, c, h, w) = ( + id / strides.0, + id % strides.0 / strides.1, + id % strides.1 / strides.2, + id % strides.2, + ); + + let (y_frac, x_frac) = if align_corners { + let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64); + let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64); + (y_ratio * h as f64, x_ratio * w as f64) + } else { + let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5; + let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5; + ( + y_frac.clamp(0.0, (in_height - 1) as f64), + x_frac.clamp(0.0, (in_width - 1) as f64), + ) + }; + let val = + bilinear_interpolate_single(&x, b, c, x_frac, y_frac, in_width - 1, in_height - 1); + + unsafe { + let output = unsafe_shared_out.get(); + output[(b, c, h, w)] = val.elem(); + } + }); + }); + + output.into_dyn().into_shared() +} + +pub(crate) fn bicubic_interpolate( + x: SharedArray, + 
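+    // Same align_corners coordinate mapping as bilinear above; out-of-range
+    // taps are handled by clamping the sample indices rather than the fractions.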
output_size: [usize; 2], + align_corners: bool, +) -> SharedArray { + fn cubic_interp1d(x0: f64, x1: f64, x2: f64, x3: f64, t: f64) -> f64 { + fn cubic_convolution1(x: f64, a: f64) -> f64 { + ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0 + } + + fn cubic_convolution2(x: f64, a: f64) -> f64 { + ((a * x - 5.0 * a) * x + 8.0 * a) * x - 4.0 * a + } + + let coeffs = [ + cubic_convolution2(t + 1.0, -0.75), + cubic_convolution1(t, -0.75), + cubic_convolution1(1.0 - t, -0.75), + cubic_convolution2(2.0 - t, -0.75), + ]; + + x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3] + } + + let x = x.into_dimensionality::().unwrap(); + + let (batch_size, channels, in_height, in_width) = x.dim(); + let [out_height, out_width] = output_size; + + let out_element_num = batch_size * channels * out_height * out_width; + let strides = ( + channels * out_height * out_width, + out_height * out_width, + out_width, + ); + + let mut output = Array4::zeros((batch_size, channels, out_height, out_width)); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, out_element_num).for_each(|id| { + let (b, c, h, w) = ( + id / strides.0, + id % strides.0 / strides.1, + id % strides.1 / strides.2, + id % strides.2, + ); + + let (y_frac, x_frac) = if align_corners { + let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64); + let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64); + (y_ratio * h as f64, x_ratio * w as f64) + } else { + let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5; + let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5; + (y_frac, x_frac) + }; + let y0 = y_frac.floor(); + let yw = y_frac - y0; + let y_in = y0 as isize; + + let x0 = x_frac.floor(); + let xw = x_frac - x0; + let x_in = x0 as isize; + + let max_h = (in_height - 1) as isize; + let max_w = (in_width - 1) as isize; + + let ys_in = [ + (y_in - 1).clamp(0, max_h) as usize, + y_in.clamp(0, max_h) as usize, + (y_in + 1).clamp(0, max_h) as usize, + (y_in + 2).clamp(0, max_h) as usize, + ]; + + let xs_in = [ + (x_in - 1).clamp(0, max_w) as usize, + x_in.clamp(0, max_w) as usize, + (x_in + 1).clamp(0, max_w) as usize, + (x_in + 2).clamp(0, max_w) as usize, + ]; + + let coefficients = ys_in.map(|y| { + cubic_interp1d( + x[(b, c, y, xs_in[0])].elem(), + x[(b, c, y, xs_in[1])].elem(), + x[(b, c, y, xs_in[2])].elem(), + x[(b, c, y, xs_in[3])].elem(), + xw, + ) + }); + + let result = cubic_interp1d( + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + yw, + ) + .elem(); + + unsafe { + let output = unsafe_shared_out.get(); + output[(b, c, h, w)] = result; + } + }); + }); + + output.into_dyn().into_shared() +} + +pub(crate) fn lanczos3_interpolate( + x: SharedArray, + output_size: [usize; 2], + align_corners: bool, +) -> SharedArray { + fn lanczos3_weight(x: f64) -> f64 { + if x == 0.0 { + return 1.0; + } + let abs_x = x.abs(); + if abs_x >= 3.0 { + return 0.0; + } + let pi = core::f64::consts::PI; + let pi_x = pi * x; + let pi_x_over_3 = pi_x / 3.0; + (pi_x.sin() * pi_x_over_3.sin()) / (pi_x * pi_x_over_3) + } + + let x = x.into_dimensionality::().unwrap(); + + let (batch_size, channels, in_height, in_width) = x.dim(); + let [out_height, out_width] = output_size; + + let out_element_num = batch_size * channels * out_height * out_width; + let strides = ( + channels * out_height * out_width, + out_height * out_width, + out_width, + ); + + let mut output = Array4::zeros((batch_size, 
channels, out_height, out_width)); + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, out_element_num).for_each(|id| { + let (b, c, h, w) = ( + id / strides.0, + id % strides.0 / strides.1, + id % strides.1 / strides.2, + id % strides.2, + ); + + let (y_frac, x_frac) = if align_corners { + let y_ratio = ((in_height - 1) as f64) / (core::cmp::max(out_height - 1, 1) as f64); + let x_ratio = ((in_width - 1) as f64) / (core::cmp::max(out_width - 1, 1) as f64); + (y_ratio * h as f64, x_ratio * w as f64) + } else { + let y_frac = (h as f64 + 0.5) * (in_height as f64 / out_height as f64) - 0.5; + let x_frac = (w as f64 + 0.5) * (in_width as f64 / out_width as f64) - 0.5; + (y_frac, x_frac) + }; + + let y0 = y_frac.floor(); + let x0 = x_frac.floor(); + let max_h = (in_height - 1) as isize; + let max_w = (in_width - 1) as isize; + + // 6x6 separable Lanczos3 filter (skip out-of-bounds positions) + let mut result = 0.0; + let mut weight_sum = 0.0; + for ky in -2..=3 { + let yi = y0 as isize + ky; + if yi < 0 || yi > max_h { + continue; + } + let y_idx = yi as usize; + let wy = lanczos3_weight(y_frac - (y0 + ky as f64)); + for kx in -2..=3 { + let xi = x0 as isize + kx; + if xi < 0 || xi > max_w { + continue; + } + let x_idx = xi as usize; + let wx = lanczos3_weight(x_frac - (x0 + kx as f64)); + let w = wy * wx; + let pixel: f64 = x[(b, c, y_idx, x_idx)].elem(); + result += pixel * w; + weight_sum += w; + } + } + if weight_sum != 0.0 { + result /= weight_sum; + } + + unsafe { + let output = unsafe_shared_out.get(); + output[(b, c, h, w)] = result.elem(); + } + }); + }); + + output.into_dyn().into_shared() +} + +/// Sample an element of the source array with bilinear interpolation +/// +/// * `source` - The tensor to read from. Has shape (batch_size, channels, height, width) +/// * `b` - The batch to read from +/// * `c` - The channel to read from +/// * `x` - The x position to read in the array +/// * `y` - The y position to read in the array +/// * `x_max` - The max x position (inclusive) +/// * `y_max` - The max y position (inclusive) +/// +/// # Returns +/// +/// The interpolated value read from the array +pub(crate) fn bilinear_interpolate_single( + source: &ArrayBase>, + b: usize, + c: usize, + x: f64, + y: f64, + x_max: usize, + y_max: usize, +) -> f64 +where + E: FloatNdArrayElement, + S: DataOwned, +{ + let y0 = y.floor(); + let y1 = ceil_clamp(y, y_max); + let yw = y - y0; + + let x0 = x.floor(); + let x1 = ceil_clamp(x, x_max); + let xw = x - x0; + + let (x0, x1, y0, y1) = (x0 as usize, x1 as usize, y0 as usize, y1 as usize); + + let p_a = source[(b, c, y0, x0)].elem::() * (1.0 - xw) * (1.0 - yw); + let p_b = source[(b, c, y0, x1)].elem::() * xw * (1.0 - yw); + let p_c = source[(b, c, y1, x0)].elem::() * (1.0 - xw) * yw; + let p_d = source[(b, c, y1, x1)].elem::() * xw * yw; + + p_a + p_b + p_c + p_d +} diff --git a/crates/burn/src/ops/macros.rs b/crates/burn/src/ops/macros.rs new file mode 100644 index 00000000..b3ac4f94 --- /dev/null +++ b/crates/burn/src/ops/macros.rs @@ -0,0 +1,107 @@ +macro_rules! 
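+// keepdim!: reduce along `$dim` but keep it as a size-1 axis, mirroring
+// NumPy's `keepdims=True` (reduce, then reshape the axis back in).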
keepdim { + ( + $dim:expr, + $self:expr, + mean + ) => {{ + // Get shape first (via reference), then pass ownership to avoid clone + let mut shape = $self.shape().into_shape(); + shape[$dim] = 1; + let tensor: SharedArray = mean_dim($self, $dim); + NdArrayOps::reshape(tensor, shape) + }}; + ( + $dim:expr, + $self:expr, + sum + ) => {{ + // Get shape first (via reference), then pass ownership to avoid clone + let mut shape = $self.shape().into_shape(); + shape[$dim] = 1; + let tensor: SharedArray = sum_dim($self, $dim); + NdArrayOps::reshape(tensor, shape) + }}; + ( + $dim:expr, + $self:expr, + prod + ) => {{ + // Get shape first (via reference), then pass ownership to avoid clone + let mut shape = $self.shape().into_shape(); + shape[$dim] = 1; + let tensor: SharedArray = prod_dim($self, $dim); + NdArrayOps::reshape(tensor, shape) + }}; +} + +use burn_backend::ElementConversion; +pub(crate) use keepdim; +use ndarray::{Axis, Zip}; + +use crate::{SharedArray, element::NdArrayElement}; + +pub(crate) fn mean_dim(tensor: SharedArray, dim: usize) -> SharedArray { + tensor.mean_axis(Axis(dim)).unwrap().into_shared() +} + +pub(crate) fn sum_dim(tensor: SharedArray, dim: usize) -> SharedArray { + tensor.sum_axis(Axis(dim)).into_shared() +} + +pub(crate) fn prod_dim(tensor: SharedArray, dim: usize) -> SharedArray { + tensor + .fold_axis(Axis(dim), 1.elem::(), |acc, &x| acc.mul(x.elem())) + .into_shared() +} + +/// Generic cumulative operation function with closure-based operation. +pub(crate) fn cumulative_with_op(tensor: SharedArray, dim: usize, op: F) -> SharedArray +where + E: NdArrayElement, + F: Fn(&mut E, &E), +{ + let axis = Axis(dim); + let shape = tensor.shape().to_vec(); + // Use into_owned() instead of to_owned() - only copies if shared, avoids copy if unique + let mut result = tensor.into_owned(); + let dim_size = shape[dim]; + + for i in 1..dim_size { + let prev = result.index_axis(axis, i - 1).to_owned(); + let mut current = result.index_axis_mut(axis, i); + Zip::from(&mut current).and(&prev).for_each(&op); + } + + result.into_shared() +} + +// Define all cumulative operation functions using the generic function +pub(crate) fn cumsum_dim(tensor: SharedArray, dim: usize) -> SharedArray { + cumulative_with_op(tensor, dim, |c, &p| *c = c.add(p.elem())) +} + +pub(crate) fn cumprod_dim(tensor: SharedArray, dim: usize) -> SharedArray { + cumulative_with_op(tensor, dim, |c, &p| *c = c.mul(p.elem())) +} + +pub(crate) fn cummin_dim>( + tensor: SharedArray, + dim: usize, +) -> SharedArray { + cumulative_with_op(tensor, dim, |c, &p| { + if p < *c { + *c = p; + } + }) +} + +pub(crate) fn cummax_dim>( + tensor: SharedArray, + dim: usize, +) -> SharedArray { + cumulative_with_op(tensor, dim, |c, &p| { + if p > *c { + *c = p; + } + }) +} diff --git a/crates/burn/src/ops/matmul.rs b/crates/burn/src/ops/matmul.rs new file mode 100644 index 00000000..3fb7b467 --- /dev/null +++ b/crates/burn/src/ops/matmul.rs @@ -0,0 +1,362 @@ +use crate::UnsafeSharedRef; +use crate::{NdArrayElement, ShapeOps, SharedArray, iter_range_par, ops::NdArrayOps, run_par}; + +use alloc::{vec, vec::Vec}; +use burn_backend::ElementConversion; +use burn_backend::Shape; +use ndarray::{IxDyn, s}; + +pub(crate) fn matmul( + lhs: SharedArray, + rhs: SharedArray, +) -> SharedArray { + let shape_lhs = lhs.shape(); + let shape_rhs = rhs.shape(); + let ndims = shape_lhs.num_dims(); + let m = shape_lhs[ndims - 2]; // # of left rows + let k = shape_rhs[ndims - 2]; // # of left cols and right rows + let n = shape_rhs[ndims - 1]; // # of right 
cols + + let (out_shape, strides_lhs, strides_rhs, strides_out) = output_shape(shape_lhs, shape_rhs); + let l_mat_size = m * k; // size of matrix component of left array + let r_mat_size = k * n; // size of matrix component of right array + let out_mat_size = m * n; // size of matrix component of output array + + let num_l_batches = shape_lhs.num_elements() / l_mat_size; + let num_r_batches = shape_rhs.num_elements() / r_mat_size; + let num_out_batches = out_shape.num_elements() / out_mat_size; + + let lhs_array = NdArrayOps::reshape(lhs, Shape::new([num_l_batches, m, k])); + let rhs_array = NdArrayOps::reshape(rhs, Shape::new([num_r_batches, k, n])); + + let alpha: E = 1.0.elem(); + let beta: E = 0.0.elem(); + + let out = run_par!(|| { + let mut out_array = ndarray::Array3::::zeros((num_out_batches, m, n)); + let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array); + + iter_range_par!(0, num_out_batches).for_each(|out_batch| { + // Here, we: + // 1. Un-flatten the output batch into a component-based batch index. + // 2. Use the strides for left and right batch indices to convert it to a flattened + // batch for left and right. + let out_index = strides_out.unflatten(out_batch); + let l_batch = strides_lhs.flatten(&out_index); + let r_batch = strides_rhs.flatten(&out_index); + + let lhs_slice = lhs_array.slice(s!(l_batch, .., ..)); + let rhs_slice = rhs_array.slice(s!(r_batch, .., ..)); + + unsafe { + let mut out_slice = unsafe_shared_out_array + .get() + .slice_mut(s!(out_batch, .., ..)); + + ndarray::linalg::general_mat_mul( + alpha, + &lhs_slice, + &rhs_slice, + beta, + &mut out_slice, + ) + } + }); + + out_array.into_shared().into_dyn() + }); + + NdArrayOps::reshape(out, out_shape) +} + +#[derive(Debug, PartialEq)] +struct Strides { + strides: Vec, +} +impl Strides { + fn new(strides: Vec) -> Self { + Strides { strides } + } + + fn unflatten(&self, linear_index: usize) -> Vec { + let mut coord = Vec::with_capacity(self.strides.len()); + let mut rem = linear_index; + for stride in self.strides.iter() { + coord.push(rem / stride); + rem %= stride; + } + coord + } + + fn flatten(&self, index: &Vec) -> usize { + assert_eq!(self.strides.len(), index.len()); + self.strides + .iter() + .zip(index) + .map(|(stride, index)| stride * index) + .sum() + } +} + +/// Compute the (broadcasted) output shape of matrix multiplication, along with strides for +/// the non-matrix dimensions of all arrays. +/// +/// # Arguments +/// * `lsh`: Shape of the first (left-hand) matrix multiplication argument. +/// * `rsh`: Shape of the second (right-hand) matrix multiplication argument. +/// +/// # Panics +/// * If `D` is not at least 2. +/// * If the matrix multiplication dimensions (last 2) are incompatible. +/// * If any other dimension is not the same for both tensors, or equal to 1. (Any dimension where +/// one dim is equal to 1 is broadcast.) +fn output_shape(lsh: &[usize], rsh: &[usize]) -> (Shape, Strides, Strides, Strides) { + let ndims = lsh.num_dims(); + if ndims < 2 { + panic!("Matrix multiplication requires an array with at least 2 dimensions."); + } + + // Fetch matrix dimensions and check compatibility. + let l_rows = lsh[ndims - 2]; + let l_cols = lsh[ndims - 1]; + let r_rows = rsh[ndims - 2]; + let r_cols = rsh[ndims - 1]; + if l_cols != r_rows { + panic!("Dimensions are incompatible for matrix multiplication."); + } + // Set matrix dimensions of the output shape. 
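+    // The matrix dims m x n pass straight through; the batch dims are filled
+    // in, with broadcasting, by the stride loop below. A stride of 0 makes
+    // every output batch re-read the same input matrix: e.g. for
+    // lhs [1, 5, 3] x rhs [4, 3, 7], osh = [4, 5, 7] with l_strides = [0] and
+    // r_strides = [1], so all four output batches share the single left matrix.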
+ let mut osh = vec![0; ndims]; + osh[ndims - 2] = l_rows; + osh[ndims - 1] = r_cols; + + // Set other array dimensions, broadcasting as necessary. + // Compute the strides inline. + let mut cur_l_stride: usize = 1; + let mut cur_r_stride: usize = 1; + let mut cur_o_stride: usize = 1; + let mut l_strides = Vec::with_capacity(ndims - 2); + let mut r_strides = Vec::with_capacity(ndims - 2); + let mut o_strides = Vec::with_capacity(ndims - 2); + for i in (0..ndims - 2).rev() { + let l_dim = lsh[i]; + let r_dim = rsh[i]; + + // Compatible dimensions are: + // 1. Both dimensions are equal. + // 2. One of the dimensions is equal to 1. + let o_dim: usize; + if l_dim == r_dim { + o_dim = l_dim; // both dimensions are equal + l_strides.push(cur_l_stride); + r_strides.push(cur_r_stride); + } else if l_dim == 1 { + o_dim = r_dim; // broadcast the left + l_strides.push(0); + r_strides.push(cur_r_stride); + } else if r_dim == 1 { + o_dim = l_dim; // broadcast the right + l_strides.push(cur_l_stride); + r_strides.push(0); + } else { + panic!("Dimensions differ and cannot be broadcasted."); + } + osh[i] = o_dim; + o_strides.push(cur_o_stride); + cur_o_stride *= o_dim; + + cur_l_stride *= l_dim; + cur_r_stride *= r_dim; + } + l_strides.reverse(); + r_strides.reverse(); + o_strides.reverse(); + + ( + Shape::from(osh), + Strides::new(l_strides), + Strides::new(r_strides), + Strides::new(o_strides), + ) +} + +pub(crate) fn cross( + lhs: SharedArray, + rhs: SharedArray, + dim: usize, +) -> SharedArray { + let shape_lhs = lhs.shape(); + let shape_rhs = rhs.shape(); + let ndims = shape_lhs.num_dims(); + + // Broadcast the shapes except along dim + let mut broadcast_shape = vec![0; ndims]; + for i in 0..ndims { + if i == dim { + broadcast_shape[i] = shape_lhs[i]; // already checked to be 3 + } else { + let l = shape_lhs[i]; + let r = shape_rhs[i]; + if l == r { + broadcast_shape[i] = l; + } else if l == 1 { + broadcast_shape[i] = r; + } else if r == 1 { + broadcast_shape[i] = l; + } else { + panic!("Tensors are not broadcastable along dimension {}", i); + } + } + } + + // Broadcast lhs and rhs + let lhs_broadcast = if shape_lhs == broadcast_shape.as_slice() { + lhs + } else { + NdArrayOps::expand(lhs, Shape::from(broadcast_shape.clone())) + }; + let rhs_broadcast = if shape_rhs == broadcast_shape.as_slice() { + rhs + } else { + NdArrayOps::expand(rhs, Shape::from(broadcast_shape.clone())) + }; + + // Now, move dim to the last dimension + let mut perm = (0..ndims).collect::>(); + perm.remove(dim); + perm.push(dim); + + let lhs_permuted = NdArrayOps::permute(lhs_broadcast, &perm); + let rhs_permuted = NdArrayOps::permute(rhs_broadcast, &perm); + + // Reshape to (*, 3) + let total_elements = lhs_permuted.shape().num_elements(); + let batch_size = total_elements / 3; + let lhs_reshaped = NdArrayOps::reshape(lhs_permuted, Shape::new([batch_size, 3])); + let rhs_reshaped = NdArrayOps::reshape(rhs_permuted, Shape::new([batch_size, 3])); + + // Compute cross product + let mut result = ndarray::ArrayD::::zeros(IxDyn(&[batch_size, 3])); + for i in 0..batch_size { + let a1 = lhs_reshaped[IxDyn(&[i, 0])]; + let a2 = lhs_reshaped[IxDyn(&[i, 1])]; + let a3 = lhs_reshaped[IxDyn(&[i, 2])]; + let b1 = rhs_reshaped[IxDyn(&[i, 0])]; + let b2 = rhs_reshaped[IxDyn(&[i, 1])]; + let b3 = rhs_reshaped[IxDyn(&[i, 2])]; + result[IxDyn(&[i, 0])] = a2.mul(b3).sub(a3.mul(b2)); + result[IxDyn(&[i, 1])] = a3.mul(b1).sub(a1.mul(b3)); + result[IxDyn(&[i, 2])] = a1.mul(b2).sub(a2.mul(b1)); + } + + let result_shared = result.into_shared(); + + // 
Reshape back to the broadcast shape with dim at the end + let mut result_shape = broadcast_shape; + result_shape.remove(dim); + result_shape.push(3); + let result_reshaped = NdArrayOps::reshape(result_shared, Shape::from(result_shape)); + + // Permute back + let mut inv_perm = vec![0; ndims]; + for (i, &p) in perm.iter().enumerate() { + inv_perm[p] = i; + } + NdArrayOps::permute(result_reshaped, &inv_perm) +} + +#[cfg(test)] +mod tests { + use super::*; + + impl Strides { + fn empty() -> Self { + Strides { + strides: Vec::with_capacity(0), + } + } + } + + #[test] + fn test_output_shape() { + // plain matrix multiply + assert_eq!( + output_shape(&[5, 3], &[3, 7]), + ( + Shape::from([5, 7]), + Strides::empty(), + Strides::empty(), + Strides::empty() + ) + ); + // matrix multiply with one extra stack dimension + assert_eq!( + output_shape(&[4, 5, 3], &[4, 3, 7]), + ( + Shape::from([4, 5, 7]), + Strides::new(vec![1]), + Strides::new(vec![1]), + Strides::new(vec![1]) + ) + ); + // rank 3, broadcast left + assert_eq!( + output_shape(&[1, 5, 3], &[4, 3, 7]), + ( + Shape::from([4, 5, 7]), + Strides::new(vec![0]), + Strides::new(vec![1]), + Strides::new(vec![1]) + ) + ); + // rank 3, broadcast right + assert_eq!( + output_shape(&[4, 5, 3], &[1, 3, 7]), + ( + Shape::from([4, 5, 7]), + Strides::new(vec![1]), + Strides::new(vec![0]), + Strides::new(vec![1]) + ) + ); + // rank 4, multi broadcast + assert_eq!( + output_shape(&[1, 4, 5, 3], &[8, 1, 3, 7]), + ( + Shape::from([8, 4, 5, 7]), + Strides::new(vec![0, 1]), + Strides::new(vec![1, 0]), + Strides::new(vec![4, 1]) + ) + ); + // rank 5, multi-broadcast + assert_eq!( + output_shape(&[1, 3, 4, 5, 3], &[8, 3, 1, 3, 7]), + ( + Shape::from([8, 3, 4, 5, 7]), + Strides::new(vec![0, 4, 1]), + Strides::new(vec![3, 1, 0]), + Strides::new(vec![12, 4, 1]) + ) + ) + } + + #[test] + #[should_panic( + expected = "Matrix multiplication requires an array with at least 2 dimensions." 
+    )]
+    fn test_output_shape_too_small() {
+        output_shape(&[4], &[4]);
+    }
+
+    #[test]
+    #[should_panic(expected = "Dimensions are incompatible for matrix multiplication.")]
+    fn test_output_shape_bad_matrix_dims() {
+        output_shape(&[5, 3], &[4, 7]);
+    }
+
+    #[test]
+    #[should_panic(expected = "Dimensions differ and cannot be broadcasted.")]
+    fn test_output_shape_non_broadcast() {
+        output_shape(&[4, 5, 3], &[2, 3, 7]);
+    }
+}
diff --git a/crates/burn/src/ops/maxpool.rs b/crates/burn/src/ops/maxpool.rs
new file mode 100644
index 00000000..2a162cf9
--- /dev/null
+++ b/crates/burn/src/ops/maxpool.rs
@@ -0,0 +1,247 @@
+use crate::{
+    ShapeOps, SharedArray,
+    element::{FloatNdArrayElement, IntNdArrayElement},
+    iter_range_par,
+    ops::padding::apply_padding_4d,
+    run_par,
+    sharing::UnsafeSharedRef,
+};
+
+use burn_backend::ElementConversion;
+use burn_backend::ops::conv::calculate_pool_output_size;
+use ndarray::Array4;
+
+pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
+    x: SharedArray<E>,
+    kernel_size: [usize; 2],
+    stride: [usize; 2],
+    padding: [usize; 2],
+    dilation: [usize; 2],
+    ceil_mode: bool,
+) -> SharedArray<E> {
+    let [kernel_height, kernel_width] = kernel_size;
+    let [padding_height, padding_width] = padding;
+    let [stride_height, stride_width] = stride;
+    let [dilation_height, dilation_width] = dilation;
+    let [batch_size, channels, x_height, x_width] = x.shape().dims();
+    let inf = (-f32::INFINITY).elem::<E>();
+
+    let out_height = calculate_pool_output_size(
+        kernel_height,
+        stride_height,
+        padding_height,
+        dilation_height,
+        x_height,
+        ceil_mode,
+    );
+    let out_width = calculate_pool_output_size(
+        kernel_width,
+        stride_width,
+        padding_width,
+        dilation_width,
+        x_width,
+        ceil_mode,
+    );
+
+    // Calculate the extra padding needed for ceil_mode. The maximum input
+    // position accessed is `(out_size - 1) * stride + (kernel_size - 1) * dilation`,
+    // which must fit within the padded input `input_size + 2 * padding`; any
+    // shortfall is added as extra padding.
+    let max_ih =
+        (out_height.saturating_sub(1)) * stride_height + (kernel_height - 1) * dilation_height;
+    let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width;
+    let padded_height = x_height + 2 * padding_height;
+    let padded_width = x_width + 2 * padding_width;
+    let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1));
+    let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1));
+    let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w];
+
+    let x = apply_padding_4d::<E>(x, total_padding, inf);
+
+    // `apply_padding_4d` pads both sides, but ceil_mode only needs the overflow
+    // on the bottom/right; shifting the window start by `extra_pad` keeps the
+    // windows anchored as if the extra padding were one-sided.
+    let offset_h = extra_pad_h;
+    let offset_w = extra_pad_w;
+
+    let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
+    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
+
+    run_par!(|| {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
+            let b = k / channels;
+            let c = k % channels;
+
+            let output = unsafe_shared_out.get();
+
+            for oh in 0..out_height {
+                for ow in 0..out_width {
+                    let mut max_val = inf;
+
+                    for kh in 0..kernel_height {
+                        let ih = offset_h + oh * stride_height + kh * dilation_height;
+
+                        for kw in 0..kernel_width {
+                            let iw = offset_w + ow * stride_width + kw * dilation_width;
+
+                            let val = x[[b, c, ih, iw]];
+
+                            if val > max_val {
+                                max_val = val;
+                            }
+                        }
+                    }
+
+                    output[[b, c, oh, ow]] = max_val;
+                }
+            }
+        })
+    });
+
+    output.into_dyn().into_shared()
+}
+
+pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, I: IntNdArrayElement>(
+    x: SharedArray<E>,
+    kernel_size: [usize; 2],
+    stride: [usize; 2],
+    padding: [usize; 2],
+    dilation: [usize; 2],
+    ceil_mode: bool,
+) -> (SharedArray<E>, SharedArray<I>) {
+    let [kernel_height, kernel_width] = kernel_size;
+    let [padding_height, padding_width] = padding;
+    let [stride_height, stride_width] = stride;
+    let [dilation_height, dilation_width] = dilation;
+    let [batch_size, channels, x_height, x_width] = x.shape().dims();
+    let inf = (-f32::INFINITY).elem::<E>();
+
+    let out_height = calculate_pool_output_size(
+        kernel_height,
+        stride_height,
+        padding_height,
+        dilation_height,
+        x_height,
+        ceil_mode,
+    );
+    let out_width = calculate_pool_output_size(
+        kernel_width,
+        stride_width,
+        padding_width,
+        dilation_width,
+        x_width,
+        ceil_mode,
+    );
+
+    // Calculate the extra padding needed for ceil_mode (see `max_pool2d`).
+    let max_ih =
+        (out_height.saturating_sub(1)) * stride_height + (kernel_height - 1) * dilation_height;
+    let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width;
+    let padded_height = x_height + 2 * padding_height;
+    let padded_width = x_width + 2 * padding_width;
+    let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1));
+    let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1));
+    let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w];
+
+    let x = apply_padding_4d::<E>(x, total_padding, inf);
+
+    // Offset to account for the extra (two-sided) padding.
+    let offset_h = extra_pad_h;
+    let offset_w = extra_pad_w;
+
+    let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
+    let mut indices = Array4::<I>::zeros((batch_size, channels, out_height, out_width));
+
+    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
+    let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
+
+    run_par!(|| {
+        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
+            let b = k / channels;
+            let c = k % channels;
+
+            let output = unsafe_shared_out.get();
+            let indices = unsafe_shared_indices.get();
+
+            for oh in 0..out_height {
+                for ow in 0..out_width {
+                    let mut max_val = inf;
+                    let mut index = 0;
+
+                    for kh in 0..kernel_height {
+                        let ih = offset_h + oh * stride_height + kh * dilation_height;
+
+                        for kw in 0..kernel_width {
+                            let iw = offset_w + ow * stride_width + kw * dilation_width;
+                            let val = x[[b, c, ih, iw]];
+
+                            if val > max_val {
+                                max_val = val;
+
+                                // Calculate the index in the original (unpadded) input.
+                                let ih_orig = ih as i64 - total_padding[0] as i64;
+                                let iw_orig = iw as i64 - total_padding[1] as i64;
+
+                                // Clamp to the valid range for the index calculation.
+                                let ih_clamped = ih_orig.max(0).min(x_height as i64 - 1);
+                                let iw_clamped = iw_orig.max(0).min(x_width as i64 - 1);
+
+                                index = ih_clamped * x_width as i64 + iw_clamped;
+                            }
+                        }
+                    }
+
+                    output[[b, c, oh, ow]] = max_val;
+                    indices[[b, c, oh, ow]] = index.elem();
+                }
+            }
+        })
+    });
+
+    let output = output.into_dyn().into_shared();
+    let indices = indices.into_dyn().into_shared();
+
+    (output, indices)
+}
+
+#[allow(clippy::too_many_arguments)]
+pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement, I: IntNdArrayElement>(
+    x: SharedArray<E>,
+    _kernel_size: [usize; 2],
+    _stride: [usize; 2],
+    _padding: [usize; 2],
+    _dilation: [usize; 2],
+    _ceil_mode: bool,
+    output_grad: SharedArray<E>,
+    indices: SharedArray<I>,
+) -> SharedArray<E> {
+    let [_batch_size, _channels, height, width] = output_grad.shape().dims();
+    let [batch_size, channels, height_x, width_x] = x.shape().dims();
+
+    let mut output = Array4::zeros((batch_size, channels, height_x, width_x));
+
+    let
unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + run_par!(|| { + iter_range_par!(0, batch_size * channels).for_each(|k| unsafe { + let b = k / channels; + let c = k % channels; + + let output = unsafe_shared_out.get(); + + for h in 0..height { + for w in 0..width { + let index = indices[[b, c, h, w]].elem::(); + let grad = output_grad[[b, c, h, w]]; + + let index_h = index as usize / width_x; + let index_w = index as usize % width_x; + + output[[b, c, index_h, index_w]] += grad; + } + } + }); + }); + + output.into_dyn().into_shared() +} diff --git a/crates/burn/src/ops/mod.rs b/crates/burn/src/ops/mod.rs new file mode 100644 index 00000000..f4f215ec --- /dev/null +++ b/crates/burn/src/ops/mod.rs @@ -0,0 +1,24 @@ +mod activation; +mod base; +mod bool_tensor; +mod int_tensor; +mod module; +mod qtensor; +#[cfg(feature = "simd")] +mod simd; +mod tensor; +mod transaction; + +pub(crate) mod adaptive_avgpool; +pub(crate) mod avgpool; +pub(crate) mod conv; +pub(crate) mod deform_conv; +pub(crate) mod grid_sample; +pub(crate) mod interpolate; +pub(crate) mod macros; +pub(crate) mod matmul; +pub(crate) mod maxpool; +pub(crate) mod padding; +pub(crate) mod quantization; + +pub(crate) use base::*; diff --git a/crates/burn/src/ops/module.rs b/crates/burn/src/ops/module.rs new file mode 100644 index 00000000..a7d7e27a --- /dev/null +++ b/crates/burn/src/ops/module.rs @@ -0,0 +1,381 @@ +use super::{ + adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward}, + avgpool::{avg_pool2d, avg_pool2d_backward}, + conv::{conv_transpose2d, conv_transpose3d, conv2d, conv3d}, + deform_conv::{backward::deform_conv2d_backward, deform_conv2d}, + interpolate::{ + bicubic_interpolate, bilinear_interpolate, lanczos3_interpolate, nearest_interpolate, + }, + maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices}, +}; +#[cfg(feature = "simd")] +use crate::ops::simd::{ + avgpool::try_avg_pool2d_simd, conv::try_conv2d_simd, maxpool::try_max_pool2d_simd, +}; +use crate::{ + NdArray, SharedArray, element::FloatNdArrayElement, execute_with_int_dtype, + tensor::NdArrayTensor, +}; +use crate::{ + element::{IntNdArrayElement, QuantElement}, + ops::interpolate::nearest_interpolate_backward, +}; +use burn_backend::{ + ElementConversion, TensorMetadata, + ops::{attention::attention_fallback, *}, + tensor::FloatTensor, +}; + +macro_rules! module_op { + // Module op with inputs (inp), optional (opt) and arguments (args). + // Converts NdArrayStorage to SharedArray for compatibility with existing operations. 
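+    // Illustrative expansion (hypothetical call): `module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| ...)`
+    // matches every input against the same dtype variant, binds `E` to `f32` or
+    // `f64` inside the closure, unwraps optionals with the matching variant,
+    // and panics on a variant mismatch.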
+ (inp($($x:tt),+), opt($($opt:tt),*), $element:ident, $op:expr) => {{ + #[allow(unused_parens, unreachable_patterns)] + match ($($x),+) { + ($(NdArrayTensor::F32($x)),+) => { + type $element = f32; + $op( + $($x.into_shared()),+ + $(, $opt.map(|o| match o { NdArrayTensor::F32(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))* + ) + } + ($(NdArrayTensor::F64($x)),+) => { + type $element = f64; + $op( + $($x.into_shared()),+ + $(, $opt.map(|o| match o { NdArrayTensor::F64(val) => val.into_shared(), _ => panic!("Optional argument type mismatch") }))* + ) + } + _ => panic!("Data type mismatch"), + } + }}; +} + +impl ModuleOps + for NdArray +where + NdArrayTensor: From>, + NdArrayTensor: From>, +{ + fn conv2d( + x: NdArrayTensor, + weight: NdArrayTensor, + bias: Option, + options: ConvOptions<2>, + ) -> NdArrayTensor { + module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| { + #[cfg(feature = "simd")] + let (x, weight, bias) = match try_conv2d_simd(x, weight, bias, options.clone()) { + Ok(out) => return out.into(), + Err(args) => args, + }; + conv2d::(x, weight, bias, options).into() + }) + } + + fn deform_conv2d( + x: FloatTensor, + offset: FloatTensor, + weight: FloatTensor, + mask: Option>, + bias: Option>, + options: DeformConvOptions<2>, + ) -> FloatTensor { + module_op!( + inp(x, offset, weight), + opt(mask, bias), + E, + |x, offset, weight, mask, bias| deform_conv2d::( + x, offset, weight, mask, bias, options + ) + .into() + ) + } + + fn deform_conv2d_backward( + x: FloatTensor, + offset: FloatTensor, + weight: FloatTensor, + mask: Option>, + bias: Option>, + output_grad: FloatTensor, + options: DeformConvOptions<2>, + ) -> DeformConv2dBackward { + module_op!( + inp(x, offset, weight, output_grad), + opt(mask, bias), + E, + |x, offset, weight, output_grad, mask, bias| { + let (x, offset, weight, mask, bias) = deform_conv2d_backward::( + x, + offset, + weight, + mask, + bias, + output_grad, + options, + ); + DeformConv2dBackward::new( + x.into(), + offset.into(), + weight.into(), + mask.map(|m| m.into()), + bias.map(|b| b.into()), + ) + } + ) + } + + fn conv_transpose2d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + ) -> FloatTensor { + module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| { + conv_transpose2d::(x, weight, bias, options).into() + }) + } + + fn avg_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ceil_mode: bool, + ) -> FloatTensor { + module_op!(inp(x), opt(), E, |x| { + #[cfg(feature = "simd")] + let x = match if ceil_mode { + // SIMD path doesn't support ceil_mode yet, skip it + Err(x) + } else { + try_avg_pool2d_simd(x, kernel_size, stride, padding, count_include_pad) + } { + Ok(out) => return out.into(), + Err(x) => x, + }; + avg_pool2d::( + x, + kernel_size, + stride, + padding, + count_include_pad, + ceil_mode, + ) + .into() + }) + } + + fn avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + count_include_pad: bool, + ceil_mode: bool, + ) -> FloatTensor { + module_op!(inp(x, grad), opt(), E, |x, grad| avg_pool2d_backward::( + x, + grad, + kernel_size, + stride, + padding, + count_include_pad, + ceil_mode + ) + .into()) + } + + fn max_pool2d( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ceil_mode: bool, + ) -> FloatTensor { + module_op!(inp(x), opt(), 
E, |x| { + #[cfg(feature = "simd")] + let x = match if ceil_mode { + // SIMD path doesn't support ceil_mode yet, skip it + Err(x) + } else { + try_max_pool2d_simd(x, kernel_size, stride, padding, dilation) + } { + Ok(out) => return out.into(), + Err(x) => x, + }; + max_pool2d::(x, kernel_size, stride, padding, dilation, ceil_mode).into() + }) + } + + fn max_pool2d_with_indices( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ceil_mode: bool, + ) -> MaxPool2dWithIndices> { + module_op!(inp(x), opt(), E, |x| { + let (output, indices) = max_pool2d_with_indices::( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ); + MaxPool2dWithIndices::new(output.into(), indices.into()) + }) + } + + fn max_pool2d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ceil_mode: bool, + output_grad: FloatTensor, + indices: NdArrayTensor, + ) -> MaxPool2dBackward> { + execute_with_int_dtype!(indices, IntElem, |idx_s: SharedArray| { + // Convert indices from runtime dtype to the expected I type + // (pool indices are bounded by tensor dimensions, so conversion is safe) + let indices: SharedArray = idx_s.mapv(|x| x.elem()).into_shared(); + module_op!(inp(x, output_grad), opt(), E, |x, output_grad| { + let output = max_pool2d_backward::( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + output_grad, + indices, + ); + MaxPool2dBackward::new(output.into()) + }) + }) + } + + fn adaptive_avg_pool2d(x: FloatTensor, output_size: [usize; 2]) -> FloatTensor { + module_op!(inp(x), opt(), E, |x| adaptive_avg_pool2d::( + x, + output_size + ) + .into()) + } + + fn adaptive_avg_pool2d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + module_op!(inp(x, grad), opt(), E, |x, grad| { + adaptive_avg_pool2d_backward::(x, grad).into() + }) + } + + fn interpolate( + x: FloatTensor, + output_size: [usize; 2], + options: InterpolateOptions, + ) -> FloatTensor { + match options.mode { + InterpolateMode::Nearest => { + module_op!(inp(x), opt(), E, |x| nearest_interpolate::( + x, + output_size + ) + .into()) + } + InterpolateMode::Bilinear => { + let align_corners = options.align_corners; + module_op!(inp(x), opt(), E, |x| bilinear_interpolate::( + x, + output_size, + align_corners + ) + .into()) + } + InterpolateMode::Bicubic => { + let align_corners = options.align_corners; + module_op!(inp(x), opt(), E, |x| bicubic_interpolate::( + x, + output_size, + align_corners + ) + .into()) + } + InterpolateMode::Lanczos3 => { + let align_corners = options.align_corners; + module_op!(inp(x), opt(), E, |x| lanczos3_interpolate::( + x, + output_size, + align_corners + ) + .into()) + } + } + } + + fn interpolate_backward( + x: FloatTensor, + grad: FloatTensor, + output_size: [usize; 2], + options: InterpolateOptions, + ) -> FloatTensor { + match options.mode { + InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| { + nearest_interpolate_backward::(x, grad, output_size).into() + }), + InterpolateMode::Bilinear => { + panic!("bilinear interpolation backward is not supported for ndarray backend") + } + InterpolateMode::Bicubic => { + panic!("bicubic interpolation backward is not supported for ndarray backend") + } + InterpolateMode::Lanczos3 => { + panic!("lanczos3 interpolation backward is not supported for ndarray backend") + } + } + } + + fn conv3d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: 
ConvOptions<3>, + ) -> FloatTensor { + module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| conv3d::( + x, weight, bias, options + ) + .into()) + } + + fn conv_transpose3d( + x: FloatTensor, + weight: FloatTensor, + bias: Option>, + options: ConvTransposeOptions<3>, + ) -> FloatTensor { + module_op!(inp(x, weight), opt(bias), E, |x, weight, bias| { + conv_transpose3d::(x, weight, bias, options).into() + }) + } + + fn attention( + query: FloatTensor, + key: FloatTensor, + value: FloatTensor, + mask: Option>, + attn_bias: Option>, + options: AttentionModuleOptions, + ) -> FloatTensor { + attention_fallback::(query, key, value, mask, attn_bias, options) + } +} diff --git a/crates/burn/src/ops/padding.rs b/crates/burn/src/ops/padding.rs new file mode 100644 index 00000000..d9c6fd3a --- /dev/null +++ b/crates/burn/src/ops/padding.rs @@ -0,0 +1,72 @@ +use crate::{NdArrayElement, SharedArray}; +use ndarray::{Array4, Array5}; + +use super::NdArrayOps; + +pub(crate) fn apply_padding_4d( + x: SharedArray, + padding: [usize; 2], + elem: E, +) -> SharedArray { + let [batch_size, input_channels, height, width] = x.shape().try_into().unwrap(); + let [padding_height, padding_width] = padding; + let padded_height = height + 2 * padding_height; + let padded_width = width + 2 * padding_width; + + let x_new = Array4::from_elem( + (batch_size, input_channels, padded_height, padded_width), + elem, + ); + let mut x_new = x_new.into_shared().into_dyn(); + + x_new = NdArrayOps::slice_assign( + x_new, + &[ + burn_backend::Slice::from(0..batch_size), + burn_backend::Slice::from(0..input_channels), + burn_backend::Slice::from(padding_height..height + padding_height), + burn_backend::Slice::from(padding_width..width + padding_width), + ], + x, + ); + + x_new +} + +pub(crate) fn apply_padding_5d( + x: SharedArray, + padding: [usize; 3], + elem: E, +) -> SharedArray { + let [batch_size, input_channels, depth, height, width] = x.shape().try_into().unwrap(); + let [padding_depth, padding_height, padding_width] = padding; + let padded_depth = depth + 2 * padding_depth; + let padded_height = height + 2 * padding_height; + let padded_width = width + 2 * padding_width; + + let x_new = Array5::from_elem( + ( + batch_size, + input_channels, + padded_depth, + padded_height, + padded_width, + ), + elem, + ); + let mut x_new = x_new.into_shared().into_dyn(); + + x_new = NdArrayOps::slice_assign( + x_new, + &[ + burn_backend::Slice::from(0..batch_size), + burn_backend::Slice::from(0..input_channels), + burn_backend::Slice::from(padding_depth..depth + padding_depth), + burn_backend::Slice::from(padding_height..height + padding_height), + burn_backend::Slice::from(padding_width..width + padding_width), + ], + x, + ); + + x_new +} diff --git a/crates/burn/src/ops/qtensor.rs b/crates/burn/src/ops/qtensor.rs new file mode 100644 index 00000000..a7210fc8 --- /dev/null +++ b/crates/burn/src/ops/qtensor.rs @@ -0,0 +1,353 @@ +use alloc::{vec, vec::Vec}; + +use burn_backend::{ + DType, ExecutionError, Shape, TensorData, TensorMetadata, + ops::QTensorOps, + quantization::{ + QParams, QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue, + QuantizationParametersPrimitive, QuantizedBytes, + }, + tensor::{FloatTensor, IntTensor, QuantizedTensor}, +}; +use burn_std::{FloatDType, IntDType}; + +use crate::{ + FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayQTensor, NdArrayTensor, SharedArray, + element::{IntNdArrayElement, QuantElement}, + execute_with_dtype, execute_with_int_dtype, execute_with_int_out_dtype, + 
execute_with_numeric_dtype, slice, +}; + +use super::quantization::{QuantizationStrategy, SymmetricQuantization}; +use super::{NdArrayMathOps, NdArrayOps}; + +impl QTensorOps + for NdArray +where + NdArrayTensor: From>, + NdArrayTensor: From>, +{ + fn q_from_data(data: TensorData, _device: &NdArrayDevice) -> QuantizedTensor { + match data.dtype { + DType::QFloat(scheme) => { + let shape = data.shape.clone(); + let num_elements = data.num_elements(); + let q_bytes = QuantizedBytes { + bytes: data.into_bytes(), + scheme, + num_elements, + }; + + match scheme { + QuantScheme { + level: QuantLevel::Tensor | QuantLevel::Block(_), + mode: QuantMode::Symmetric, + value: QuantValue::Q8F | QuantValue::Q8S, + .. + } => { + // We can load QuantStore::U32 w/ QuantizedBytes impl + let (values, qparams) = q_bytes.into_vec_i8(); + let data = TensorData::new(values, shape); + // Overwrite storage + let scheme = scheme.with_store(QuantStore::Native); + + let qparams = qparams + .scales + .into_iter() + .map(|scales| QParams { scales }) + .collect(); + + NdArrayQTensor { + qtensor: NdArrayTensor::from_data(data), + scheme, + qparams, + } + } + QuantScheme { + value: + QuantValue::Q4F + | QuantValue::Q4S + | QuantValue::Q2F + | QuantValue::Q2S + | QuantValue::E2M1 + | QuantValue::E4M3 + | QuantValue::E5M2, + .. + } => unimplemented!("from_data not supported for scheme {scheme:?}"), + } + } + _ => panic!( + "Invalid dtype (expected DType::QFloat, got {:?})", + data.dtype + ), + } + } + + fn quantize( + tensor: FloatTensor, + scheme: &QuantScheme, + qparams: QuantizationParametersPrimitive, + ) -> QuantizedTensor { + let shape = tensor.shape(); + let data_f = tensor.into_data(); + let scales = qparams.scales.into_data().convert::(); + + // Implement with ndarray instead of QuantizationStrategy? + let (data, qparams) = match scheme { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + #[cfg(not(feature = "export_tests"))] + value: QuantValue::Q8F | QuantValue::Q8S, + // For tests, "native" sub-byte quant serves as a reference for value equality. + // Values are stored as i8 regardless. + #[cfg(feature = "export_tests")] + value: + QuantValue::Q8F + | QuantValue::Q8S + | QuantValue::Q4F + | QuantValue::Q4S + | QuantValue::Q2F + | QuantValue::Q2S, + store: QuantStore::Native, + .. + } => { + let scales = scales.iter().next().unwrap(); + let strategy = QuantizationStrategy::PerTensorSymmetric( + SymmetricQuantization::init(scales, scheme.value), + ); + let values = strategy.quantize(data_f.as_slice().unwrap()); + ( + TensorData::quantized(values, shape.clone(), *scheme, &[scales]), + vec![QParams { scales }], + ) + } + QuantScheme { + level: QuantLevel::Block(block_size), + mode: QuantMode::Symmetric, + #[cfg(not(feature = "export_tests"))] + value: QuantValue::Q8F | QuantValue::Q8S, + #[cfg(feature = "export_tests")] + value: + QuantValue::Q8F + | QuantValue::Q8S + | QuantValue::Q4F + | QuantValue::Q4S + | QuantValue::Q2F + | QuantValue::Q2S, + store: QuantStore::Native, + .. 
+ } => { + let scales = scales.as_slice().unwrap(); + let (strategy, qparams) = scales + .iter() + .map(|&s| { + ( + SymmetricQuantization::init(s, scheme.value), + QParams { scales: s }, + ) + }) + .unzip(); + let strategy = QuantizationStrategy::PerBlockSymmetric(strategy, *block_size); + let values = strategy.quantize(data_f.as_slice().unwrap()); + ( + TensorData::quantized(values, shape.clone(), *scheme, scales), + qparams, + ) + } + scheme => unimplemented!("Quantization not supported for scheme {scheme:?}"), + }; + + let num_elements = data.num_elements(); + let q_bytes = QuantizedBytes { + bytes: data.into_bytes(), + scheme: *scheme, + num_elements, + }; + let (values, _) = q_bytes.into_vec_i8(); + let data = TensorData::new(values, shape).convert::(); + + NdArrayQTensor { + qtensor: NdArrayTensor::from_data(data), + scheme: *scheme, + qparams, + } + } + + fn dequantize(tensor: QuantizedTensor, dtype: FloatDType) -> FloatTensor { + let strategy = tensor.strategy(); + let scheme = tensor.scheme; + let shape = tensor.shape(); + let data = match tensor.qtensor { + NdArrayTensor::I8(storage) => { + let data = storage.into_shared().into_iter().collect(); + dequantize(data, shape, scheme, &strategy, dtype.into()) + } + _ => unreachable!(), + }; + NdArrayTensor::from_data(data) + } + + fn q_device(_tensor: &QuantizedTensor) -> NdArrayDevice { + NdArrayDevice::Cpu + } + + fn q_to_device( + tensor: QuantizedTensor, + _device: &NdArrayDevice, + ) -> QuantizedTensor { + tensor + } + + fn q_reshape(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { + NdArrayQTensor { + qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { + NdArrayOps::reshape(array, shape) + }), + scheme: tensor.scheme, + qparams: tensor.qparams, + } + } + + async fn q_into_data(tensor: QuantizedTensor) -> Result { + let shape = tensor.qtensor.shape(); + let scales = tensor.qparams.iter().map(|q| q.scales).collect::>(); + Ok(execute_with_numeric_dtype!( + tensor.qtensor, + E, + |array: SharedArray| { + let values = array.into_iter().collect(); + TensorData::quantized(values, shape, tensor.scheme, &scales) + } + )) + } + + fn q_swap_dims( + tensor: QuantizedTensor, + dim1: usize, + dim2: usize, + ) -> QuantizedTensor { + NdArrayQTensor { + qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { + NdArrayOps::swap_dims(array, dim1, dim2) + }), + scheme: tensor.scheme, + qparams: tensor.qparams, + } + } + + fn q_permute(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { + NdArrayQTensor { + qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { + NdArrayOps::permute(array, axes) + }), + scheme: tensor.scheme, + qparams: tensor.qparams, + } + } + + fn q_flip(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { + NdArrayQTensor { + qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { + NdArrayOps::flip(array, axes) + }), + scheme: tensor.scheme, + qparams: tensor.qparams, + } + } + + fn q_gather( + dim: usize, + tensor: QuantizedTensor, + indices: IntTensor, + ) -> QuantizedTensor { + let qtensor = execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray< + IntElem, + >| + -> NdArrayTensor { + execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { + NdArrayOps::gather(dim, array, idx_array) + }) + }); + NdArrayQTensor { + qtensor, + scheme: tensor.scheme, + qparams: tensor.qparams, + } + } + + fn q_select( + tensor: QuantizedTensor, + dim: usize, + indices: IntTensor, + ) -> QuantizedTensor { + let qtensor = 
execute_with_int_dtype!(indices, IntElem, |idx_array: SharedArray< + IntElem, + >| + -> NdArrayTensor { + execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { + NdArrayMathOps::select(array, dim, idx_array) + }) + }); + NdArrayQTensor { + qtensor, + scheme: tensor.scheme, + qparams: tensor.qparams, + } + } + + fn q_slice( + tensor: QuantizedTensor, + slices: &[burn_backend::Slice], + ) -> QuantizedTensor { + NdArrayQTensor { + qtensor: slice!(tensor.qtensor, slices), + scheme: tensor.scheme, + qparams: tensor.qparams, + } + } + + fn q_argmax(tensor: QuantizedTensor, dim: usize, out_dtype: IntDType) -> IntTensor { + execute_with_int_out_dtype!(out_dtype, I, { + execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { + NdArrayMathOps::argmax::(array, dim) + }) + }) + } + + fn q_argmin(tensor: QuantizedTensor, dim: usize, out_dtype: IntDType) -> IntTensor { + execute_with_int_out_dtype!(out_dtype, I, { + execute_with_numeric_dtype!(tensor.qtensor, E, |array: SharedArray| { + NdArrayMathOps::argmin::(array, dim) + }) + }) + } + + fn q_expand(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { + NdArrayQTensor { + qtensor: execute_with_dtype!(tensor.qtensor, E, |array: SharedArray| { + NdArrayOps::expand(array, shape) + }), + scheme: tensor.scheme, + qparams: tensor.qparams, + } + } +} + +fn dequantize( + data: Vec, + shape: Shape, + scheme: QuantScheme, + strategy: &QuantizationStrategy, + dtype: DType, +) -> TensorData { + let qparams = match strategy { + QuantizationStrategy::PerTensorSymmetric(quant) => vec![quant.scale], + QuantizationStrategy::PerBlockSymmetric(quant, _block_size) => { + quant.iter().map(|q| q.scale).collect() + } + }; + let q_bytes = QuantizedBytes::new(data, scheme, &qparams); + let (values, _qparams) = q_bytes.into_vec_i8(); + TensorData::new(strategy.dequantize(&values), shape).convert_dtype(dtype) +} diff --git a/crates/burn/src/ops/quantization.rs b/crates/burn/src/ops/quantization.rs new file mode 100644 index 00000000..adaf1b16 --- /dev/null +++ b/crates/burn/src/ops/quantization.rs @@ -0,0 +1,218 @@ +use alloc::vec::Vec; +use num_traits::{Float, PrimInt}; + +use burn_backend::quantization::{BlockSize, QuantValue}; + +// NOTE: this mainly serves as a simple reference implementation. +// The de/quantization ops should be refactored to use ndarray. + +/// Quantization strategy. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QuantizationStrategy { + /// Per-tensor symmetric quantization. + PerTensorSymmetric(SymmetricQuantization), + /// Per-block symmetric quantization. + PerBlockSymmetric(Vec>, BlockSize), +} + +impl QuantizationStrategy { + /// Quantize the values to a lower precision data type. + pub fn quantize(&self, values: &[f32]) -> Vec { + match self { + QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.quantize(values), + QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => { + let block_elems = block_size.num_elements(); + let num_blocks = strategy.len(); + let numel = values.len(); + assert_eq!( + numel / block_elems, + num_blocks, + "Invalid per-block quantization with num blocks {num_blocks} and {numel} values" + ); + values + .chunks(block_elems) + .enumerate() + .flat_map(|(block_id, block)| strategy[block_id].quantize(block)) + .collect() + } + } + } + + /// Dequantize the values to a higher precision data type. 
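+    ///
+    /// A minimal sketch (assuming a per-tensor scale of 0.1 and Q8S values):
+    ///
+    /// ```ignore
+    /// let strategy = QuantizationStrategy::PerTensorSymmetric(
+    ///     SymmetricQuantization::init(0.1f32, QuantValue::Q8S),
+    /// );
+    /// assert_eq!(strategy.dequantize(&[-10, 0, 10]), vec![-1.0, 0.0, 1.0]);
+    /// ```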
+ pub fn dequantize(&self, values: &[i8]) -> Vec { + match self { + QuantizationStrategy::PerTensorSymmetric(strategy) => strategy.dequantize(values), + QuantizationStrategy::PerBlockSymmetric(strategy, block_size) => { + let block_elems = block_size.num_elements(); + let num_blocks = strategy.len(); + let numel = values.len(); + assert_eq!( + numel / block_elems, + num_blocks, + "Invalid per-block quantization with block size {block_elems}, num blocks {num_blocks} and {numel} values" + ); + values + .chunks(block_elems) + .enumerate() + .flat_map(|(block_id, block)| strategy[block_id].dequantize(block)) + .collect() + } + } + } +} + +/// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision +/// data type `Q` and vice-versa. +pub trait Quantization { + /// Returns the quantization range `[a, b]`. + fn range(&self) -> (E, E); + /// Convert the values to a lower precision data type. + fn quantize(&self, values: &[E]) -> Vec; + /// Convert a single value to a lower precision data type. + fn quantize_one(&self, value: E) -> Q; + /// Convert the values back to a higher precision data type. + fn dequantize(&self, values: &[Q]) -> Vec; + /// Convert a single value back to a higher precision data type. + fn dequantize_one(&self, value: Q) -> E; +} + +fn valid_scale(mut scale: E) -> E { + // If scale is 0 (most likely due to a tensor full of zeros), we arbitrarily adjust the + // scale to 0.1 to avoid division by zero. + if scale.eq(&E::zero()) { + scale = E::from(0.1).unwrap(); + } + scale +} + +/// Symmetric quantization scheme. +#[derive(Debug, Clone, Copy)] +pub struct SymmetricQuantization { + /// The scaling factor. + pub scale: E, + // The quantization value data type. + value: QuantValue, +} + +impl SymmetricQuantization { + /// Initialize a symmetric quantization scheme with the given parameters. + pub fn init(scale: E, value: QuantValue) -> Self { + Self { + scale: valid_scale(scale), + value, + } + } + + #[allow(dead_code)] + /// Create a new quantization scheme for an input range `[alpha, beta]`. 
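+    ///
+    /// Symmetric quantization pins the zero-point at 0, so only the scale is
+    /// derived: with Q8S the target range is `[-127, 127]`, and for
+    /// `[alpha, beta] = [-1.8, 0.5]` the half-range is `max(1.8, 0.5) = 1.8`,
+    /// giving `scale = (1.8 + 1.8) / (127 - (-127)) ≈ 0.01417`.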
+ fn new(alpha: E, beta: E, value: QuantValue) -> Self { + let (a, b) = value.range(); + let a = E::from(a).unwrap(); + let b = E::from(b).unwrap(); + + // Compute scale to convert a floating point value in range `[-alpha, alpha]` to the quantized range + let alpha = alpha.abs().max(beta.abs()); + let scale = valid_scale((alpha + alpha) / (b - a)); + Self { scale, value } + } +} + +impl Quantization for SymmetricQuantization { + fn quantize(&self, values: &[E]) -> Vec { + values.iter().map(|x| self.quantize_one(*x)).collect() + } + + fn dequantize(&self, values: &[Q]) -> Vec { + values.iter().map(|x_q| self.dequantize_one(*x_q)).collect() + } + + fn quantize_one(&self, value: E) -> Q { + let (a, b) = self.range(); + + // x_q = clamp(round(x / scale), a, b) + Q::from(value.div(self.scale).round().clamp(a, b)).unwrap() + } + + fn dequantize_one(&self, value: Q) -> E { + // x = scale * x_q + self.scale * E::from(value).unwrap() + } + + fn range(&self) -> (E, E) { + let (a, b) = self.value.range(); + let a = E::from(a).unwrap(); + let b = E::from(b).unwrap(); + (a, b) + } +} + +impl PartialEq for SymmetricQuantization { + fn eq(&self, other: &Self) -> bool { + self.scale == other.scale + } +} + +impl Eq for SymmetricQuantization {} + +#[cfg(test)] +mod tests { + use burn_backend::TensorData; + + use super::*; + use alloc::vec; + + #[test] + fn test_int8_symmetric_quantization() { + let x: [f32; 4] = [-1.8, -1.0, 0.0, 0.5]; + let expected_q = vec![-127, -71, 0, 35]; + let expected_d = vec![-1.8, -1.0062993, 0.0, 0.496063]; + + let symmetric = SymmetricQuantization::::new(-1.8, 0.5, QuantValue::Q8S); + + let q: Vec = symmetric.quantize(&x); + assert_eq!(q, expected_q); + + let d = symmetric.dequantize(&expected_q); + + assert_eq!(d, expected_d); + } + + #[test] + fn test_int8_symmetric_quantization_per_block() { + let x: [f32; 8] = [-1.8, -1.0, 0.0, 0.5, -1.8, -1.0, 0.0, 0.5]; + let expected_q = vec![-127, -71, 0, 35, -127, -71, 0, 35]; + let expected_d = vec![ + -1.8, -1.0062993, 0.0, 0.496063, -1.8, -1.0062993, 0.0, 0.496063, + ]; + + let symmetric = SymmetricQuantization::::new(-1.8, 0.5, QuantValue::Q8S); + let strategy = QuantizationStrategy::PerBlockSymmetric( + vec![symmetric, symmetric], + BlockSize::new([4]), + ); + + let q: Vec = strategy.quantize(&x); + assert_eq!(q, expected_q); + + let d = symmetric.dequantize(&expected_q); + + assert_eq!(d, expected_d); + } + + #[test] + fn should_support_dequantize() { + let strategy = QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization { + scale: 0.1, + value: QuantValue::Q8S, + }); + + let output = strategy.dequantize(&[-127i8, -77, -26, 25, 76, 127]); + + let output = TensorData::new(output, [2, 3]); + + output.assert_approx_eq::( + &TensorData::from([[-12.7, -7.7, -2.6], [2.5, 7.6, 12.7]]), + Default::default(), + ); + } +} diff --git a/crates/burn/src/ops/simd/avgpool.rs b/crates/burn/src/ops/simd/avgpool.rs new file mode 100644 index 00000000..41d5ba61 --- /dev/null +++ b/crates/burn/src/ops/simd/avgpool.rs @@ -0,0 +1,443 @@ +use core::{marker::PhantomData, mem::transmute}; + +use crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef}; + +use burn_backend::DType; +use burn_backend::{Element, ElementConversion}; +use bytemuck::Zeroable; +use macerator::{Simd, VAdd, VDiv}; +use ndarray::{Array4, s}; +use nhwc::avg_pool_nhwc; + +use super::should_use_simd; + +#[macerator::with_simd] +fn is_accelerated(_x: PhantomData) -> bool { + ::is_accelerated::() && ::is_accelerated::() +} + +pub(crate) fn try_avg_pool2d_simd( + 
x: SharedArray, + ksize: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + with_pad: bool, +) -> Result, SharedArray> { + // Strides must be unit, dilation isn't supported, rows must be contiguous + if x.strides()[1] != 1 || !should_use_simd(x.shape()[1]) { + return Err(x); + } + + match E::dtype() { + DType::F64 if is_accelerated::(PhantomData) => Ok(cast(avg_pool_nhwc::( + cast(x), + ksize, + stride, + padding, + with_pad, + ))), + DType::F32 if is_accelerated::(PhantomData) => Ok(cast(avg_pool_nhwc::( + cast(x), + ksize, + stride, + padding, + with_pad, + ))), + _ => Err(x), + } +} + +fn cast(tensor: SharedArray) -> SharedArray { + unsafe { transmute::, SharedArray>(tensor) } +} + +mod nhwc { + use itertools::Itertools; + use macerator::{Simd, Vector, vload_unaligned, vstore_unaligned}; + use ndarray::{ArrayView3, ArrayViewMut3}; + use seq_macro::seq; + + use crate::ops::simd::lanes; + + use super::*; + + // Until you can use associated constants as array size, we need to hardcode this. + // The most common config (x86-v3) has 16 registers, so use half of them for accumulators. + const BLOCK_REGISTERS: usize = 8; + + pub(crate) fn avg_pool_nhwc( + x: SharedArray, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + with_pad: bool, + ) -> SharedArray { + let [kernel_height, kernel_width] = kernel_size; + let [pad_h, pad_w] = padding; + let [stride_height, stride_width] = stride; + let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); + let lanes = lanes::(); + + let ch_block = lanes * BLOCK_REGISTERS; + + let out_height = ((x_height + 2 * pad_h - (kernel_height - 1) - 1) / stride_height) + 1; + let out_width = ((x_width + 2 * pad_w - (kernel_width - 1) - 1) / stride_width) + 1; + + let mut output = unsafe { + Array4::::uninit((batch_size, out_height, out_width, channels)).assume_init() + }; + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + let x = x.view(); + let x = x.permuted_axes(vec![0, 2, 3, 1]); + + // Floor division ensures `blocks * lanes * blocking factor` is always `<= out_channels`. + // An exclusive loop will always have `lanes * blocking factor` elements in bounds. + let blocks = channels / ch_block; + let blocks_end = blocks * ch_block; + // Floor division means simd_end is always divisible by `lanes` and `<= out_channels`. An + // exclusive loop will always have `lanes` elements in bounds. + let simd_end = channels / lanes * lanes; + let num_simd_unblocked = (simd_end - blocks_end) / lanes; + let remainder = channels - simd_end; + + run_par!(|| { + // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. + iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe { + let block = k % blocks; + let b = k / blocks; + + let output = unsafe_shared_out.get(); + + let x = x.slice(s![b, .., .., ..]); + let out = output.slice_mut(s![b, .., .., ..]); + + loop_blocked(x, out, kernel_size, stride, padding, with_pad, block); + }); + // SAFETY: See `loop_unblocked` + iter_range_par!(0, batch_size * num_simd_unblocked).for_each(|k| unsafe { + let ch = (k % num_simd_unblocked) * lanes + blocks_end; + let b = k / num_simd_unblocked; + + let output = unsafe_shared_out.get(); + + let x = x.slice(s![b, .., .., ..]); + let out = output.slice_mut(s![b, .., .., ..]); + + loop_unblocked(x, out, kernel_size, stride, padding, with_pad, ch); + }); + // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. 
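+            // Worked example (illustrative): f32 with 8-lane SIMD and
+            // channels = 70 gives ch_block = 64, blocks = 1, simd_end = 64 and
+            // remainder = 6, so channels 64..70 take this scalar path.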
+ iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe { + let ch = (k % remainder) + simd_end; + let b = k / remainder; + + let output = unsafe_shared_out.get(); + + let x = x.slice(s![b, .., .., ..]); + let out = output.slice_mut(s![b, .., .., ..]); + + loop_scalar(x, out, kernel_size, stride, padding, with_pad, ch); + }); + }); + + output = output.permuted_axes([0, 3, 1, 2]); + + output.into_dyn().into_shared() + } + + /// Execute the blocked (unrolled) portion of the pool. + #[allow( + clippy::too_many_arguments, + clippy::erasing_op, + clippy::identity_op, + unused_mut + )] + #[macerator::with_simd] + fn loop_blocked<'a, S: Simd, E: Element + VAdd + VDiv>( + x: ArrayView3<'a, E>, + mut out: ArrayViewMut3<'a, E>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + with_pad: bool, + block: usize, + ) where + 'a: 'a, + { + let [kernel_height, kernel_width] = kernel_size; + let [pad_h, pad_w] = padding; + let [stride_height, stride_width] = stride; + + let (x_height, x_width, _) = x.dim(); + let (out_height, out_width, _) = out.dim(); + let lanes = E::lanes::(); + + let ch_block = lanes * BLOCK_REGISTERS; + + // If pixels are more than `padding` from the edges, the in pixel cannot be out of bounds + for oh in pad_h..out_height.saturating_sub(pad_h) { + for ow in pad_w..out_width.saturating_sub(pad_w) { + seq!(N in 0..8 { + let mut sum~N: Vector = Zeroable::zeroed(); + }); + let ch = block * ch_block; + let ch_end = ch + ch_block; + let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh - pad_h; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw - pad_w; + let x = x.slice(s![ih, iw, ch..ch_end]); + + seq!(N in 0..8 { + // SAFETY: + // Load a full vector from x[N * lanes]. This is bounds checked by the + // slice above. + sum~N += unsafe { vload_unaligned(&x[N * lanes]) }; + }); + } + } + + let count = kernel_height * kernel_width; + let count = (count as u64).elem::(); + let count_v = count.splat(); + seq!(N in 0..8 { + let s~N = sum~N / count_v; + // SAFETY: + // Store a full vector to out[N * lanes]. This is bounds checked by the + // slice above. + unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; + }); + } + } + + // Border pixels need bounds checks + if (pad_h, pad_w) != (0, 0) { + let v_borders = (0..pad_h) + .chain(out_height.saturating_sub(pad_h)..out_height) + .cartesian_product(0..out_width); + let h_borders = (0..out_height) + .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); + + for (oh, ow) in v_borders.chain(h_borders) { + seq!(N in 0..8 { + let mut sum~N: Vector = Zeroable::zeroed(); + }); + let mut count: usize = 0; + let ch = block * ch_block; + let ch_end = ch + ch_block; + let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh; + if ih < pad_h || ih >= x_height + pad_h { + continue; + } + let ih = ih - pad_h; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw; + if iw < pad_w || iw >= x_width + pad_w { + continue; + } + let iw = iw - pad_w; + count += 1; + + let x = x.slice(s![ih, iw, ch..ch_end]); + + seq!(N in 0..8 { + // SAFETY: + // Load a full vector from x[N * lanes]. This is bounds checked by the + // slice above. 
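+                            // The guards above compare in padded coordinates
+                            // (`ih < pad_h`) before subtracting `pad_h`, which
+                            // avoids unsigned underflow at the borders.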
+ sum~N += unsafe { vload_unaligned(&x[N * lanes]) }; + }); + } + } + + if with_pad { + count = kernel_height * kernel_width; + } + + let count = (count as u64).elem::(); + let count_v = count.splat(); + seq!(N in 0..8 { + let s~N = sum~N / count_v; + // SAFETY: + // Store a full vector to out[N * lanes]. This is bounds checked by the + // slice above. + unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; + }); + } + } + } + + /// Execute the unblocked (not unrolled) portion of the pool. + /// + /// SAFETY: Safe as long as `ch + simd_lanes <= out_channels`. + #[allow(clippy::too_many_arguments, unused_mut)] + #[macerator::with_simd] + unsafe fn loop_unblocked<'a, S: Simd, E: Element + VAdd + VDiv>( + x: ArrayView3<'a, E>, + mut out: ArrayViewMut3<'a, E>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + with_pad: bool, + ch: usize, + ) where + 'a: 'a, + { + let [kernel_height, kernel_width] = kernel_size; + let [pad_h, pad_w] = padding; + let [stride_height, stride_width] = stride; + + let (x_height, x_width, _) = x.dim(); + let (out_height, out_width, _) = out.dim(); + + // If pixels are not within padding range, bounds checks are always true + for oh in pad_h..out_height - pad_h { + for ow in pad_w..out_width - pad_w { + let mut sum: Vector = Zeroable::zeroed(); + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh - pad_h; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw - pad_w; + // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` + let s0 = unsafe { vload_unaligned(&x[[ih, iw, ch]]) }; + sum += s0; + } + } + + let count = kernel_height * kernel_width; + let count: E = (count as u64).elem(); + let count_v = count.splat(); + let s0 = sum / count_v; + // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. + unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) }; + } + } + + // Border pixels need bounds checks + if (pad_h, pad_w) != (0, 0) { + let v_borders = (0..pad_h) + .chain(out_height.saturating_sub(pad_h)..out_height) + .cartesian_product(0..out_width); + let h_borders = (0..out_height) + .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); + + for (oh, ow) in v_borders.chain(h_borders) { + let mut sum: Vector = Zeroable::zeroed(); + let mut count: usize = 0; + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh; + if ih < pad_h || ih >= x_height + pad_h { + continue; + } + let ih = ih - pad_h; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw; + if iw < pad_w || iw >= x_width + pad_w { + continue; + } + let iw = iw - pad_w; + count += 1; + + // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` + sum += unsafe { vload_unaligned(&x[[ih, iw, ch]]) }; + } + } + + if with_pad { + count = kernel_height * kernel_width; + } + + let count = (count as u64).elem::(); + let count_v = count.splat(); + let s0 = sum / count_v; + // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. 
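+                // With `with_pad` the divisor above was reset to the full
+                // window size (count_include_pad semantics); otherwise only
+                // the in-bounds taps were counted.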
+ unsafe { vstore_unaligned(&mut out[[oh, ow, ch]], s0) }; + } + } + } + + /// Execute scalar portion of the pooling + #[allow(clippy::too_many_arguments)] + fn loop_scalar( + x: ArrayView3<'_, E>, + mut out: ArrayViewMut3<'_, E>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + with_pad: bool, + ch: usize, + ) { + let [kernel_height, kernel_width] = kernel_size; + let [pad_h, pad_w] = padding; + let [stride_height, stride_width] = stride; + + let (x_height, x_width, _) = x.dim(); + let (out_height, out_width, _) = out.dim(); + + // If pixels are not within padding range, bounds checks are always true + for oh in pad_h..out_height.saturating_sub(pad_h) { + for ow in pad_w..out_width.saturating_sub(pad_w) { + let mut sum: E = Zeroable::zeroed(); + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh - pad_h; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw - pad_w; + sum = sum + x[[ih, iw, ch]]; + } + } + + let count = (kernel_height * kernel_width) as u64; + out[[oh, ow, ch]] = sum / count.elem(); + } + } + + // Border pixels need bounds checks + if (pad_h, pad_w) != (0, 0) { + let v_borders = (0..pad_h) + .chain(out_height.saturating_sub(pad_h)..out_height) + .cartesian_product(0..out_width); + let h_borders = (0..out_height) + .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); + + for (oh, ow) in v_borders.chain(h_borders) { + let mut sum: E = Zeroable::zeroed(); + let mut count: usize = 0; + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh; + if ih < pad_h || ih >= x_height + pad_h { + continue; + } + let ih = ih - pad_h; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw; + if iw < pad_w || iw >= x_width + pad_w { + continue; + } + let iw = iw - pad_w; + count += 1; + sum = sum + x[[ih, iw, ch]]; + } + } + + if with_pad { + count = kernel_height * kernel_width; + } + + out[[oh, ow, ch]] = sum / (count as u64).elem(); + } + } + } +} diff --git a/crates/burn/src/ops/simd/base.rs b/crates/burn/src/ops/simd/base.rs new file mode 100644 index 00000000..005316f7 --- /dev/null +++ b/crates/burn/src/ops/simd/base.rs @@ -0,0 +1,115 @@ +use core::{marker::PhantomData, mem::MaybeUninit}; + +use macerator::{Arch, Scalar, Simd}; +use ndarray::{ArcArray, ArrayD, IxDyn, ShapeBuilder}; + +/// Whether SIMD instructions are worth using +#[cfg(all( + any( + target_arch = "x86", + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "wasm32", + target_arch = "loongarch64" + ), + not(test) +))] +pub fn should_use_simd(len: usize) -> bool { + len >= 32 +} + +/// Whether SIMD instructions are worth using +#[cfg(all( + not(any( + target_arch = "x86", + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "wasm32", + target_arch = "loongarch64" + )), + not(test) +))] +pub fn should_use_simd(_len: usize) -> bool { + false +} + +#[cfg(test)] +pub fn should_use_simd(_len: usize) -> bool { + true +} + +pub(crate) fn lanes() -> usize { + #[allow(non_camel_case_types)] + struct lanes<__T0>(__T0); + + impl ::macerator::WithSimd for lanes> { + type Output = usize; + #[inline(always)] + fn with_simd<__S: ::macerator::Simd>(self) -> ::Output { + let Self(__ty) = self; + #[allow(unused_unsafe)] + unsafe { + lanes_simd::<__S, E>(__ty) + } + } + } + (Arch::new()).dispatch(lanes(PhantomData::)) +} + +fn lanes_simd(_ty: PhantomData) -> usize { + E::lanes::() +} + +pub(crate) fn uninit_array_like(reference: &ArcArray) -> ArrayD { + let shape = reference.raw_dim(); + let strides = 
reference.strides(); + let strides = strides.iter().map(|it| *it as usize).collect::>(); + let shape_strides = shape.strides(IxDyn(&strides)); + let size = reference.len(); + let mut out_data: Vec> = Vec::with_capacity(size); + unsafe { out_data.set_len(size) }; + unsafe { ArrayD::from_shape_vec_unchecked(shape_strides, out_data).assume_init() } +} + +pub trait MinMax { + fn min(self, other: Self) -> Self; + fn max(self, other: Self) -> Self; +} + +macro_rules! impl_minmax { + ($ty: ty) => { + impl MinMax for $ty { + fn min(self, other: Self) -> Self { + Ord::min(self, other) + } + fn max(self, other: Self) -> Self { + Ord::max(self, other) + } + } + }; + ($($ty: ty),*) => { + $(impl_minmax!($ty);)* + } +} + +impl_minmax!(u8, i8, u16, i16, u32, i32, u64, i64); + +impl MinMax for f32 { + fn min(self, other: Self) -> Self { + self.min(other) + } + + fn max(self, other: Self) -> Self { + self.max(other) + } +} + +impl MinMax for f64 { + fn min(self, other: Self) -> Self { + self.min(other) + } + + fn max(self, other: Self) -> Self { + self.max(other) + } +} diff --git a/crates/burn/src/ops/simd/binary.rs b/crates/burn/src/ops/simd/binary.rs new file mode 100644 index 00000000..dae3ed57 --- /dev/null +++ b/crates/burn/src/ops/simd/binary.rs @@ -0,0 +1,299 @@ +use core::{marker::PhantomData, slice}; + +use burn_backend::Element; +use macerator::{ + Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload_unaligned, + vstore_unaligned, +}; +use ndarray::ArrayD; +use seq_macro::seq; + +use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like}; + +use super::{ + MinMax, + binary_elemwise::{ + VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecDiv, VecMax, VecMin, VecMul, VecSub, + }, + should_use_simd, +}; + +pub trait SimdBinop { + fn apply_vec(lhs: Vector, rhs: Vector) -> Vector; + fn apply(lhs: T, rhs: T) -> Out; + fn is_accelerated() -> bool; +} + +impl SimdBinop for VecAdd { + fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { + lhs + rhs + } + + fn apply(lhs: T, rhs: T) -> T { + lhs + rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl SimdBinop for VecDiv { + fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { + lhs / rhs + } + + fn apply(lhs: T, rhs: T) -> T { + lhs / rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl SimdBinop for VecMul { + fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { + lhs * rhs + } + + fn apply(lhs: T, rhs: T) -> T { + lhs * rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl SimdBinop for VecSub { + fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { + lhs - rhs + } + + fn apply(lhs: T, rhs: T) -> T { + lhs - rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl SimdBinop for VecMin { + fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { + lhs.min(rhs) + } + + fn apply(lhs: T, rhs: T) -> T { + MinMax::min(lhs, rhs) + } + + fn is_accelerated() -> bool { + ::is_min_max_accelerated::() + } +} + +impl SimdBinop for VecMax { + fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { + lhs.max(rhs) + } + + fn apply(lhs: T, rhs: T) -> T { + MinMax::max(lhs, rhs) + } + + fn is_accelerated() -> bool { + ::is_min_max_accelerated::() + } +} + +impl SimdBinop for VecBitAnd { + fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { + lhs & rhs + } + + fn apply(lhs: T, rhs: T) -> T { + lhs.bitand(rhs) + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl SimdBinop for VecBitOr { + fn apply_vec(lhs: Vector, rhs: Vector) -> 
Vector { + lhs | rhs + } + + fn apply(lhs: T, rhs: T) -> T { + lhs.bitor(rhs) + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl SimdBinop for VecBitXor { + fn apply_vec(lhs: Vector, rhs: Vector) -> Vector { + lhs ^ rhs + } + + fn apply(lhs: T, rhs: T) -> T { + lhs.bitxor(rhs) + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +#[macerator::with_simd] +fn is_accelerated>( + _x: PhantomData<(T, Out, Op)>, +) -> bool { + Op::is_accelerated::() +} + +#[allow(clippy::result_large_err)] +pub fn try_binary_simd< + E: Element, + EOut: Element, + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: SimdBinop, +>( + lhs: SharedArray, + rhs: SharedArray, +) -> Result, (SharedArray, SharedArray)> { + let lhs_len = lhs.len(); + let rhs_len = rhs.len(); + if !should_use_simd(lhs_len.max(rhs_len)) + || !lhs.is_standard_layout() + || !rhs.is_standard_layout() + || lhs.shape() != rhs.shape() + || !is_accelerated::(PhantomData) + { + return Err((lhs, rhs)); + } + // Used to assert traits based on the dynamic `DType`. + let lhs = unsafe { core::mem::transmute::, SharedArray>(lhs) }; + let rhs = unsafe { core::mem::transmute::, SharedArray>(rhs) }; + let out = binary_simd_same::(lhs, rhs); + + // Used to assert traits based on the dynamic `DType`. + let out = unsafe { core::mem::transmute::, SharedArray>(out) }; + Ok(out) +} + +fn binary_simd_same< + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: SimdBinop, +>( + lhs: SharedArray, + rhs: SharedArray, +) -> SharedArray { + let out = if lhs.is_unique() { + let mut buf = lhs.into_owned(); + let lhs = buf.as_slice_mut().unwrap(); + let rhs = rhs.as_slice().unwrap(); + let out = + unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(lhs)) }; + binary(lhs, rhs, out, PhantomData::); + unsafe { core::mem::transmute::, ArrayD>(buf) } + } else if rhs.is_unique() { + let mut buf = rhs.into_owned(); + let lhs = lhs.as_slice().unwrap(); + let rhs = buf.as_slice_mut().unwrap(); + let out = + unsafe { core::mem::transmute::<&mut [T], &mut [Out]>(unsafe_alias_slice_mut(rhs)) }; + binary(lhs, rhs, out, PhantomData::); + unsafe { core::mem::transmute::, ArrayD>(buf) } + } else { + let mut out = uninit_array_like(&lhs); + let lhs = lhs.as_slice().unwrap(); + let rhs = rhs.as_slice().unwrap(); + let out_slice = out.as_slice_mut().unwrap(); + binary(lhs, rhs, out_slice, PhantomData::); + out + }; + out.into_shared() +} + +#[allow(clippy::erasing_op, clippy::identity_op)] +#[macerator::with_simd] +fn binary< + 'a, + S: Simd, + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: SimdBinop, +>( + lhs: &'a [T], + rhs: &'a [T], + out: &'a mut [Out], + _op: PhantomData, +) where + 'a: 'a, +{ + let lanes = T::lanes::(); + let mut chunks_lhs = lhs.chunks_exact(8 * lanes); + let mut chunks_rhs = rhs.chunks_exact(8 * lanes); + let mut chunks_out = out.chunks_exact_mut(8 * lanes); + while let Some(((lhs, rhs), out)) = chunks_lhs + .next() + .zip(chunks_rhs.next()) + .zip(chunks_out.next()) + { + seq!(N in 0..8 { + // Load one full vector from `lhs`. + // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + let lhs~N = unsafe { vload_unaligned::(&lhs[N * lanes]) }; + // Load one full vector from `rhs`. + // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) }; + let s~N = Op::apply_vec(lhs~N, rhs~N); + // Store one full vector to `out`. 
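+            // Eight independent vectors per iteration keep loads and stores in
+            // flight; the single-vector loop and scalar loop below consume the
+            // remainder.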
+ // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; + }); + } + let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes); + let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes); + let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); + while let Some(((lhs, rhs), out)) = chunks_lhs + .next() + .zip(chunks_rhs.next()) + .zip(chunks_out.next()) + { + // Load one full vector from `lhs`. + // SAFETY: Guaranteed to be in bounds because `len == lanes` + let lhs0 = unsafe { vload_unaligned::(lhs.as_ptr()) }; + // Load one full vector from `rhs`. + // SAFETY: Guaranteed to be in bounds because `len == lanes` + let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) }; + let s0 = Op::apply_vec(lhs0, rhs0); + // Store one full vector to `out`. + // SAFETY: Guaranteed to be in bounds because `len == lanes` + unsafe { vstore_unaligned(out.as_mut_ptr(), s0) }; + } + + for ((lhs, rhs), out) in chunks_lhs + .remainder() + .iter() + .zip(chunks_rhs.remainder()) + .zip(chunks_out.into_remainder()) + { + *out = Op::apply(*lhs, *rhs) + } +} + +/// Unsafely alias a slice to use as an inline argument +fn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] { + let ptr = slice.as_mut_ptr(); + let len = slice.len(); + unsafe { slice::from_raw_parts_mut(ptr, len) } +} diff --git a/crates/burn/src/ops/simd/binary_elemwise.rs b/crates/burn/src/ops/simd/binary_elemwise.rs new file mode 100644 index 00000000..7534da53 --- /dev/null +++ b/crates/burn/src/ops/simd/binary_elemwise.rs @@ -0,0 +1,419 @@ +use core::marker::PhantomData; + +use bytemuck::cast; +use macerator::{ + Scalar, Simd, VAdd, VBitAnd, VBitOr, VBitXor, VDiv, VMul, VOrd, VSub, Vector, vload, + vload_unaligned, vstore, vstore_unaligned, +}; +use ndarray::ArrayD; +use seq_macro::seq; + +use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like}; + +use super::{MinMax, should_use_simd}; + +pub trait ScalarSimdBinop { + type Rhs: Copy; + type RhsVec: Copy; + fn splat(rhs: Self::Rhs) -> Self::RhsVec; + fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector; + fn apply(lhs: T, rhs: Self::Rhs) -> Out; + fn is_accelerated() -> bool; +} + +pub struct VecAdd; +pub struct VecDiv; +pub struct VecMul; +pub struct VecSub; +pub struct VecMin; +pub struct VecMax; +pub struct VecClamp; +pub struct VecBitAnd; +pub struct VecBitOr; +pub struct VecBitXor; + +impl ScalarSimdBinop for VecAdd { + type Rhs = T; + type RhsVec = Vector; + + fn splat(rhs: Self::Rhs) -> Self::RhsVec { + rhs.splat() + } + + fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { + lhs + rhs + } + + fn apply(lhs: T, rhs: T) -> T { + lhs + rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl ScalarSimdBinop for VecDiv { + type Rhs = T; + type RhsVec = Vector; + + fn splat(rhs: Self::Rhs) -> Self::RhsVec { + rhs.splat() + } + + fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { + lhs / rhs + } + + fn apply(lhs: T, rhs: T) -> T { + lhs / rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl ScalarSimdBinop for VecMul { + type Rhs = T; + type RhsVec = Vector; + + fn splat(rhs: Self::Rhs) -> Self::RhsVec { + rhs.splat() + } + + fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { + lhs * rhs + } + + fn apply(lhs: T, rhs: T) -> T { + lhs * rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl ScalarSimdBinop for VecSub { + type Rhs = T; + type RhsVec = Vector; + + fn splat(rhs: Self::Rhs) -> 
Self::RhsVec { + rhs.splat() + } + + fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { + lhs - rhs + } + + fn apply(lhs: T, rhs: T) -> T { + lhs - rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl ScalarSimdBinop for VecMin { + type Rhs = T; + type RhsVec = Vector; + + fn splat(rhs: Self::Rhs) -> Self::RhsVec { + rhs.splat() + } + + fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { + lhs.min(rhs) + } + + fn apply(lhs: T, rhs: T) -> T { + lhs.min(rhs) + } + + fn is_accelerated() -> bool { + ::is_min_max_accelerated::() + } +} + +impl ScalarSimdBinop for VecMax { + type Rhs = T; + type RhsVec = Vector; + + fn splat(rhs: Self::Rhs) -> Self::RhsVec { + rhs.splat() + } + + fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { + lhs.max(rhs) + } + + fn apply(lhs: T, rhs: T) -> T { + lhs.max(rhs) + } + + fn is_accelerated() -> bool { + ::is_min_max_accelerated::() + } +} + +impl ScalarSimdBinop for VecClamp { + type Rhs = (T, T); + type RhsVec = (Vector, Vector); + + fn splat((min, max): Self::Rhs) -> Self::RhsVec { + (min.splat(), max.splat()) + } + + fn apply_vec(lhs: Vector, (min, max): Self::RhsVec) -> Vector { + lhs.min(max).max(min) + } + + fn apply(lhs: T, (min, max): Self::Rhs) -> T { + lhs.min(max).max(min) + } + + fn is_accelerated() -> bool { + ::is_min_max_accelerated::() + } +} + +impl ScalarSimdBinop for VecBitAnd { + type Rhs = T; + type RhsVec = Vector; + + fn splat(rhs: Self::Rhs) -> Self::RhsVec { + rhs.splat() + } + + fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { + lhs & rhs + } + + fn apply(lhs: T, rhs: Self::Rhs) -> T { + lhs & rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl ScalarSimdBinop for VecBitOr { + type Rhs = T; + type RhsVec = Vector; + + fn splat(rhs: Self::Rhs) -> Self::RhsVec { + rhs.splat() + } + + fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { + lhs | rhs + } + + fn apply(lhs: T, rhs: Self::Rhs) -> T { + lhs | rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +impl ScalarSimdBinop for VecBitXor { + type Rhs = T; + type RhsVec = Vector; + + fn splat(rhs: Self::Rhs) -> Self::RhsVec { + rhs.splat() + } + + fn apply_vec(lhs: Vector, rhs: Self::RhsVec) -> Vector { + lhs ^ rhs + } + + fn apply(lhs: T, rhs: Self::Rhs) -> T { + lhs ^ rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +#[macerator::with_simd] +fn is_accelerated>( + _x: PhantomData<(T, Out, Op)>, +) -> bool { + Op::is_accelerated::() +} + +pub fn try_binary_scalar_simd< + E: NdArrayElement, + EOut: NdArrayElement, + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: ScalarSimdBinop, +>( + input: SharedArray, + elem: Op::Rhs, +) -> Result, SharedArray> { + if !should_use_simd(input.len()) + || input.as_slice_memory_order().is_none() + || !is_accelerated::(PhantomData) + { + return Err(input); + } + // Used to assert traits based on the dynamic `DType`. + let input = unsafe { core::mem::transmute::, SharedArray>(input) }; + let out = if size_of::() == size_of::() + && align_of::() >= align_of::() + && input.is_unique() + { + unsafe { binary_scalar_simd_inplace::(input, elem) } + } else { + binary_scalar_simd_owned::(input, elem) + }; + // Used to assert traits based on the dynamic `DType`. + let out = unsafe { core::mem::transmute::, SharedArray>(out) }; + Ok(out) +} + +/// Execute operation in place on an owned tensor +/// SAFETY: +/// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. 
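+//
+// For reference, the guard in `try_binary_scalar_simd` above is what upholds this
+// contract before taking the in-place path; a sketch with the type parameters
+// spelled out (`T` is the input element, `Out` the output element):
+//
+//     if size_of::<T>() == size_of::<Out>()
+//         && align_of::<T>() >= align_of::<Out>()
+//         && input.is_unique()
+//     {
+//         unsafe { binary_scalar_simd_inplace::<T, Out, Op>(input, elem) }
+//     } else {
+//         binary_scalar_simd_owned::<T, Out, Op>(input, elem)
+//     }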
+unsafe fn binary_scalar_simd_inplace< + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: ScalarSimdBinop, +>( + input: SharedArray, + elem: Op::Rhs, +) -> SharedArray { + let mut buffer = input.into_owned(); + let slice = buffer.as_slice_memory_order_mut().unwrap(); + unsafe { binary_scalar_slice_inplace::(slice, elem, PhantomData) }; + // Buffer has the same elem size and is filled with the operation output, so this is safe + let out = unsafe { core::mem::transmute::, ArrayD>(buffer) }; + out.into_shared() +} + +/// Create a new copy of the tensor as the output +fn binary_scalar_simd_owned< + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: ScalarSimdBinop, +>( + input: SharedArray, + elem: Op::Rhs, +) -> SharedArray { + let mut out = uninit_array_like(&input); + let input = input.as_slice_memory_order().unwrap(); + let out_slice = out.as_slice_memory_order_mut().unwrap(); + binary_scalar_slice::(input, out_slice, elem, PhantomData); + out.into_shared() +} + +#[inline(always)] +#[allow(clippy::erasing_op, clippy::identity_op)] +#[macerator::with_simd] +fn binary_scalar_slice< + 'a, + S: Simd, + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: ScalarSimdBinop, +>( + input: &'a [T], + out: &'a mut [Out], + rhs: Op::Rhs, + _op: PhantomData, +) where + 'a: 'a, +{ + let lanes = T::lanes::(); + let mut chunks_input = input.chunks_exact(8 * lanes); + let mut chunks_out = out.chunks_exact_mut(8 * lanes); + let rhs_vec = Op::splat::(rhs); + while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { + seq!(N in 0..8 { + // Load one full vector from `input`. + // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + let s~N = unsafe { vload_unaligned(&input[N * lanes]) }; + let s~N = Op::apply_vec(s~N, rhs_vec); + // Store one full vector to `out`. + // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; + }); + } + let mut chunks_input = chunks_input.remainder().chunks_exact(lanes); + let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); + while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { + // Load one full vector from `input`. + // SAFETY: Guaranteed to be in bounds because `len == lanes` + let s0 = unsafe { vload_unaligned(input.as_ptr()) }; + let s0 = Op::apply_vec(s0, rhs_vec); + // Store one full vector to `out`. + // SAFETY: Guaranteed to be in bounds because `len == lanes` + unsafe { vstore_unaligned(out.as_mut_ptr(), s0) }; + } + + for (input, out) in chunks_input + .remainder() + .iter() + .zip(chunks_out.into_remainder()) + { + *out = Op::apply(*input, rhs) + } +} + +/// Execute operation in line. +/// SAFETY: +/// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. +#[inline(always)] +#[macerator::with_simd] +unsafe fn binary_scalar_slice_inplace< + 'a, + S: Simd, + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: ScalarSimdBinop, +>( + buf: &'a mut [T], + rhs: Op::Rhs, + _op: PhantomData<(Out, Op)>, +) where + 'a: 'a, +{ + let (head, main, tail) = unsafe { buf.align_to_mut::>() }; + for elem in head.iter_mut().chain(tail) { + *elem = cast(Op::apply(*elem, rhs)); + } + let mut chunks = main.chunks_exact_mut(8); + let rhs = Op::splat::(rhs); + for elem in chunks.by_ref() { + seq!(N in 0..8 { + // Load a full vector from the aligned portion of the buffer. 
+ // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is + // always a full vector in bounds. + let s~N = unsafe { vload(&elem[N] as *const _ as *const T) }; + let s~N = Op::apply_vec(s~N, rhs); + // Store a full vector at the same position as the input. Cast is safe because `Out` is + // size and align compatible + unsafe { vstore_unaligned(&mut elem[N] as *mut _ as *mut Out, s~N) }; + }); + } + + for elem in chunks.into_remainder() { + // Load a full vector from the aligned portion of the buffer. + // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is + // always a full vector in bounds. + let s0 = unsafe { vload(elem as *const _ as *const T) }; + + let s0 = Op::apply_vec(s0, rhs); + // Store a full vector at the same position as the input. Cast is safe because `Out` is + // size and align compatible + unsafe { vstore(elem as *mut _ as *mut Out, s0) }; + } +} diff --git a/crates/burn/src/ops/simd/cmp.rs b/crates/burn/src/ops/simd/cmp.rs new file mode 100644 index 00000000..c9f8c0ea --- /dev/null +++ b/crates/burn/src/ops/simd/cmp.rs @@ -0,0 +1,374 @@ +use core::{marker::PhantomData, slice}; + +use burn_backend::Element; +use macerator::{Mask, Scalar, Simd, VEq, VOrd, Vector, vload_unaligned}; +use ndarray::ArrayD; +use seq_macro::seq; + +use crate::{NdArrayElement, SharedArray, ops::simd::uninit_array_like}; + +use super::should_use_simd; + +pub trait SimdCmpOp { + fn apply_vec(lhs: Vector, rhs: Vector) -> Mask; + fn apply(lhs: T, rhs: T) -> bool; + fn is_accelerated() -> bool; +} + +pub struct VecEquals; + +impl SimdCmpOp for VecEquals { + fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { + lhs.eq(rhs) + } + + fn apply(lhs: T, rhs: T) -> bool { + lhs == rhs + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +pub struct VecGreater; + +impl SimdCmpOp for VecGreater { + fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { + lhs.gt(rhs) + } + + fn apply(lhs: T, rhs: T) -> bool { + lhs > rhs + } + + fn is_accelerated() -> bool { + ::is_cmp_accelerated::() + } +} + +pub struct VecGreaterEq; + +impl SimdCmpOp for VecGreaterEq { + fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { + lhs.ge(rhs) + } + + fn apply(lhs: T, rhs: T) -> bool { + lhs >= rhs + } + + fn is_accelerated() -> bool { + ::is_cmp_accelerated::() + } +} + +pub struct VecLowerEq; + +impl SimdCmpOp for VecLowerEq { + fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { + lhs.le(rhs) + } + + fn apply(lhs: T, rhs: T) -> bool { + lhs <= rhs + } + + fn is_accelerated() -> bool { + ::is_cmp_accelerated::() + } +} + +pub struct VecLower; + +impl SimdCmpOp for VecLower { + fn apply_vec(lhs: Vector, rhs: Vector) -> Mask { + lhs.lt(rhs) + } + + fn apply(lhs: T, rhs: T) -> bool { + lhs < rhs + } + + fn is_accelerated() -> bool { + ::is_cmp_accelerated::() + } +} + +#[macerator::with_simd] +fn is_accelerated>(_x: PhantomData<(T, Op)>) -> bool { + Op::is_accelerated::() +} + +#[allow(clippy::result_large_err)] +pub fn try_cmp_simd>( + lhs: SharedArray, + rhs: SharedArray, +) -> Result, (SharedArray, SharedArray)> { + let lhs_len = lhs.len(); + let rhs_len = rhs.len(); + if !should_use_simd(lhs_len.max(rhs_len)) + || !lhs.is_standard_layout() + || !rhs.is_standard_layout() + || lhs.shape() != rhs.shape() + || !is_accelerated::(PhantomData) + { + return Err((lhs, rhs)); + } + // Used to assert traits based on the dynamic `DType`. 
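+    // The transmutes below are type-level no-ops: callers only reach this point
+    // after matching on the runtime `DType`, so `E` and `T` are the same concrete
+    // type and `SharedArray<E>` / `SharedArray<T>` share one layout. A sketch of
+    // such a caller (names assumed, mirroring the `launch_kernel!` dispatch used
+    // for max-pooling later in this patch):
+    //
+    //     match E::dtype() {
+    //         DType::F32 => try_cmp_simd::<E, f32, Op>(lhs, rhs),
+    //         DType::I32 => try_cmp_simd::<E, i32, Op>(lhs, rhs),
+    //         // ...one arm per supported element type
+    //         _ => Err((lhs, rhs)),
+    //     }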
+ let lhs = unsafe { core::mem::transmute::, SharedArray>(lhs) }; + let rhs = unsafe { core::mem::transmute::, SharedArray>(rhs) }; + let out = cmp_simd_same::(lhs, rhs); + + Ok(out) +} + +fn cmp_simd_same>( + lhs: SharedArray, + rhs: SharedArray, +) -> SharedArray { + let out = if lhs.is_unique() && size_of::() == size_of::() { + let mut buf = lhs.into_owned(); + let lhs = buf.as_slice_mut().unwrap(); + let rhs = rhs.as_slice().unwrap(); + let out = + unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(lhs)) }; + cmp(lhs, rhs, out, PhantomData::); + unsafe { core::mem::transmute::, ArrayD>(buf) } + } else if rhs.is_unique() && size_of::() == size_of::() { + let mut buf = rhs.into_owned(); + let lhs = lhs.as_slice().unwrap(); + let rhs = buf.as_slice_mut().unwrap(); + let out = + unsafe { core::mem::transmute::<&mut [T], &mut [bool]>(unsafe_alias_slice_mut(rhs)) }; + cmp(lhs, rhs, out, PhantomData::); + unsafe { core::mem::transmute::, ArrayD>(buf) } + } else { + let mut out = uninit_array_like(&lhs); + let lhs = lhs.as_slice().unwrap(); + let rhs = rhs.as_slice().unwrap(); + let out_slice = out.as_slice_mut().unwrap(); + cmp(lhs, rhs, out_slice, PhantomData::); + out + }; + out.into_shared() +} + +#[allow(clippy::erasing_op, clippy::identity_op)] +#[macerator::with_simd] +fn cmp<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp>( + lhs: &'a [T], + rhs: &'a [T], + out: &'a mut [bool], + _op: PhantomData, +) where + 'a: 'a, +{ + let lanes = T::lanes::(); + let mut chunks_lhs = lhs.chunks_exact(8 * lanes); + let mut chunks_rhs = rhs.chunks_exact(8 * lanes); + let mut chunks_out = out.chunks_exact_mut(8 * lanes); + while let Some(((lhs, rhs), out)) = chunks_lhs + .next() + .zip(chunks_rhs.next()) + .zip(chunks_out.next()) + { + seq!(N in 0..8 { + // Load one full vector from `lhs`. + // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + let lhs~N = unsafe { vload_unaligned::(&lhs[N * lanes]) }; + // Load one full vector from `rhs`. + // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + let rhs~N = unsafe { vload_unaligned(&rhs[N * lanes]) }; + let s~N = Op::apply_vec(lhs~N, rhs~N); + // Store one full vector to `out`. + // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) }; + }); + } + let mut chunks_lhs = chunks_lhs.remainder().chunks_exact(lanes); + let mut chunks_rhs = chunks_rhs.remainder().chunks_exact(lanes); + let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); + while let Some(((lhs, rhs), out)) = chunks_lhs + .next() + .zip(chunks_rhs.next()) + .zip(chunks_out.next()) + { + // Load one full vector from `lhs`. + // SAFETY: Guaranteed to be in bounds because `len == lanes` + let lhs0 = unsafe { vload_unaligned::(lhs.as_ptr()) }; + // Load one full vector from `rhs`. + // SAFETY: Guaranteed to be in bounds because `len == lanes` + let rhs0 = unsafe { vload_unaligned(rhs.as_ptr()) }; + let s0 = Op::apply_vec(lhs0, rhs0); + // Store one full vector to `out`. 
+ // SAFETY: Guaranteed to be in bounds because `len == lanes` + unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) }; + } + + for ((lhs, rhs), out) in chunks_lhs + .remainder() + .iter() + .zip(chunks_rhs.remainder()) + .zip(chunks_out.into_remainder()) + { + *out = Op::apply(*lhs, *rhs) + } +} + +/// Unsafely alias a slice to use as an inline argument +fn unsafe_alias_slice_mut<'a, T>(slice: &mut [T]) -> &'a mut [T] { + let ptr = slice.as_mut_ptr(); + let len = slice.len(); + unsafe { slice::from_raw_parts_mut(ptr, len) } +} + +pub use elemwise::try_cmp_scalar_simd; + +mod elemwise { + use bytemuck::cast; + use macerator::vload; + + use super::*; + + pub fn try_cmp_scalar_simd>( + input: SharedArray, + elem: T, + ) -> Result, SharedArray> { + if !should_use_simd(input.len()) + || input.as_slice_memory_order().is_none() + || !is_accelerated::(PhantomData) + { + return Err(input); + } + // Used to assert traits based on the dynamic `DType`. + let input = unsafe { core::mem::transmute::, SharedArray>(input) }; + let out = if size_of::() == size_of::() + && align_of::() >= align_of::() + && input.is_unique() + { + unsafe { cmp_scalar_simd_inplace::(input, elem) } + } else { + cmp_scalar_simd_owned::(input, elem) + }; + Ok(out) + } + + /// Execute operation in place on an owned tensor + /// SAFETY: + /// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. + unsafe fn cmp_scalar_simd_inplace>( + input: SharedArray, + elem: T, + ) -> SharedArray { + let mut buffer = input.into_owned(); + let slice = buffer.as_slice_memory_order_mut().unwrap(); + unsafe { cmp_scalar_slice_inplace::(slice, elem, PhantomData) }; + // Buffer has the same elem size and is filled with the operation output, so this is safe + let out = unsafe { core::mem::transmute::, ArrayD>(buffer) }; + out.into_shared() + } + + /// Create a new copy of the tensor as the output + fn cmp_scalar_simd_owned>( + input: SharedArray, + elem: T, + ) -> SharedArray { + let mut out = uninit_array_like(&input); + let input = input.as_slice_memory_order().unwrap(); + let out_slice = out.as_slice_memory_order_mut().unwrap(); + cmp_scalar_slice::(input, out_slice, elem, PhantomData); + out.into_shared() + } + + #[inline(always)] + #[allow(clippy::erasing_op, clippy::identity_op)] + #[macerator::with_simd] + fn cmp_scalar_slice<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp>( + input: &'a [T], + out: &'a mut [bool], + rhs: T, + _op: PhantomData, + ) where + 'a: 'a, + { + let lanes = T::lanes::(); + let mut chunks_input = input.chunks_exact(8 * lanes); + let mut chunks_out = out.chunks_exact_mut(8 * lanes); + let rhs_vec = rhs.splat::(); + while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { + seq!(N in 0..8 { + // Load one full vector from `input`. + // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + let s~N = unsafe { vload_unaligned(&input[N * lanes]) }; + let s~N = Op::apply_vec(s~N, rhs_vec); + // Store one full vector to `out`. + // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + unsafe { T::mask_store_as_bool(&mut out[N * lanes], s~N) }; + }); + } + let mut chunks_input = chunks_input.remainder().chunks_exact(lanes); + let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); + while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { + // Load one full vector from `input`. 
+ // SAFETY: Guaranteed to be in bounds because `len == lanes` + let s0 = unsafe { vload_unaligned(input.as_ptr()) }; + let s0 = Op::apply_vec(s0, rhs_vec); + // Store one full vector to `out`. + // SAFETY: Guaranteed to be in bounds because `len == lanes` + unsafe { T::mask_store_as_bool(out.as_mut_ptr(), s0) }; + } + + for (input, out) in chunks_input + .remainder() + .iter() + .zip(chunks_out.into_remainder()) + { + *out = Op::apply(*input, rhs) + } + } + + /// Execute operation in line. + /// SAFETY: + /// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. + #[inline(always)] + #[macerator::with_simd] + unsafe fn cmp_scalar_slice_inplace<'a, S: Simd, T: NdArrayElement + Scalar, Op: SimdCmpOp>( + buf: &'a mut [T], + rhs: T, + _op: PhantomData, + ) where + 'a: 'a, + { + let (head, main, tail) = unsafe { buf.align_to_mut::>() }; + for elem in head.iter_mut().chain(tail) { + *elem = cast(Op::apply(*elem, rhs)); + } + let mut chunks = main.chunks_exact_mut(8); + let rhs = rhs.splat::(); + for elem in chunks.by_ref() { + seq!(N in 0..8 { + // Load a full vector from the aligned portion of the buffer. + // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is + // always a full vector in bounds. + let s~N = unsafe { vload(&elem[N] as *const _ as *const T) }; + let s~N = Op::apply_vec(s~N, rhs); + // Store a full vector at the same position as the input. Cast is safe because `Out` is + // size and align compatible + unsafe { T::mask_store_as_bool(&mut elem[N] as *mut _ as *mut bool, s~N) }; + }); + } + + for elem in chunks.into_remainder() { + // Load a full vector from the aligned portion of the buffer. + // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is + // always a full vector in bounds. + let s0 = unsafe { vload(elem as *const _ as *const T) }; + + let s0 = Op::apply_vec(s0, rhs); + // Store a full vector at the same position as the input. 
Cast is safe because `bool` is
+            // size- and align-compatible.
+            unsafe { T::mask_store_as_bool(elem as *mut _ as *mut bool, s0) };
+        }
+    }
+}
diff --git a/crates/burn/src/ops/simd/conv.rs b/crates/burn/src/ops/simd/conv.rs
new file mode 100644
index 00000000..5bbd4633
--- /dev/null
+++ b/crates/burn/src/ops/simd/conv.rs
@@ -0,0 +1,494 @@
+use core::{marker::PhantomData, mem::transmute};
+
+use burn_backend::{
+    DType, Element,
+    ops::{ConvOptions, conv::calculate_conv_output_size},
+};
+use bytemuck::Zeroable;
+use macerator::{Simd, VMulAdd, Vector, vload_unaligned, vstore_unaligned};
+use ndarray::{
+    ArcArray1, Array4, ArrayView3, ArrayView4, ArrayViewMut2, ArrayViewMut3, Dim, Ix1, Ix4, s,
+};
+use seq_macro::seq;
+
+use crate::{FloatNdArrayElement, SharedArray, UnsafeSharedRef, iter_range_par, run_par};
+
+type Args<E> = (SharedArray<E>, SharedArray<E>, Option<SharedArray<E>>);
+
+#[allow(clippy::result_large_err)]
+pub fn try_conv2d_simd<E: Element>(
+    x: SharedArray<E>,
+    weight: SharedArray<E>,
+    bias: Option<SharedArray<E>>,
+    options: ConvOptions<2>,
+) -> Result<SharedArray<E>, Args<E>> {
+    match E::dtype() {
+        DType::F64 => conv2d::<E, f64>(x, weight, bias, options, PhantomData),
+        DType::F32 => conv2d::<E, f32>(x, weight, bias, options, PhantomData),
+        DType::I64 => conv2d::<E, i64>(x, weight, bias, options, PhantomData),
+        DType::I32 => conv2d::<E, i32>(x, weight, bias, options, PhantomData),
+        DType::I16 => conv2d::<E, i16>(x, weight, bias, options, PhantomData),
+        DType::U64 => conv2d::<E, u64>(x, weight, bias, options, PhantomData),
+        DType::U32 => conv2d::<E, u32>(x, weight, bias, options, PhantomData),
+        DType::U16 => conv2d::<E, u16>(x, weight, bias, options, PhantomData),
+        _ => Err((x, weight, bias)),
+    }
+}
+
+fn cast<T, E>(tensor: SharedArray<T>) -> SharedArray<E> {
+    unsafe { transmute::<SharedArray<T>, SharedArray<E>>(tensor) }
+}
+
+/// Out-channel last SIMD accelerated direct convolution. Loop order and register blocking based on
+/// E. Georganas, S. Avancha, K. Banerjee, D. Kalamkar, G. Henry, H. Pabst, A. Heinecke (2018).
+/// Anatomy Of High-Performance Deep Learning Convolutions On SIMD Architectures.
+/// SC '18, Article 6, pp. 1-12. arXiv:1808.05567. <https://arxiv.org/abs/1808.05567>.
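+// Conceptually, the kernel below implements the paper's loop nest, with the
+// out-channel vector innermost (the SIMD dimension) and a register block of
+// output columns unrolled around it; a sketch, not the literal code:
+//
+//     for (b, oc_block) in batches x oc_blocks {        // parallelized
+//         for oh in out_rows {
+//             for ow_block in register_blocked_out_cols {
+//                 for ic { for kh { for kw {
+//                     // each acc[n] holds `lanes` consecutive out channels
+//                     acc[n] += x[ic, ih, iw + n] * w[ic, kh, kw, oc..oc + lanes];
+//                 }}}
+//             }
+//         }
+//     }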
+#[allow(clippy::result_large_err)] +fn conv2d( + x: SharedArray, + weight: SharedArray, + bias: Option>, + options: ConvOptions<2>, + _ty: PhantomData, +) -> Result, Args> { + let [out_channels, _, k_height, k_width] = weight.shape().try_into().unwrap(); + let channels_per_group = out_channels / options.groups; + + #[macerator::with_simd] + fn precheck(_ty: PhantomData) -> (usize, bool) { + (E::lanes::(), E::is_accelerated::()) + } + + let (lanes, accelerated) = precheck::(PhantomData); + + if !accelerated || !channels_per_group.is_multiple_of(lanes) { + return Err((x, weight, bias)); + } + + let x = cast::<_, E>(x); + let weight = cast::<_, E>(weight); + let bias = bias.map(|bias| cast::<_, E>(bias)); + + let [batch_size, _in_channels, in_height, in_width] = x.shape().try_into().unwrap(); + let [dilate_h, dilate_w] = options.dilation; + let [stride_h, stride_w] = options.stride; + let [pad_h, pad_w] = options.padding; + let padded = options.padding != [0, 0]; + let strided = options.stride != [1, 1] || options.dilation != [1, 1]; + let grouped = options.groups != 1; + + let out_height = calculate_conv_output_size(k_height, stride_h, pad_h, dilate_h, in_height); + let out_width = calculate_conv_output_size(k_width, stride_w, pad_w, dilate_w, in_width); + + let x = x.into_dimensionality::().unwrap(); + let weights = weight.into_dimensionality::().unwrap(); + let weights = weights.permuted_axes([1, 2, 3, 0]); + let weights = weights.as_standard_layout(); + let bias = bias.map(|bias| bias.into_dimensionality::().unwrap()); + // floor division means `(oc_blocks - 1) * lanes` can never be greater than `out_channels - lanes`. + let oc_blocks = out_channels / lanes; + + let mut out = unsafe { + Array4::::uninit(Dim([batch_size, out_height, out_width, out_channels])).assume_init() + }; + let unsafe_shared_out = UnsafeSharedRef::new(&mut out); + + run_par!(|| { + // SAFETY: Slices are guaranteed to be non-overlapping, so having an unsafe shared reference + // is safe. `oc_blocks * lanes` must be `<= out_channels` to satisfy safety of inner function. + iter_range_par!(0, batch_size * oc_blocks).for_each(|k| unsafe { + let b = k / oc_blocks; + let ob = k % oc_blocks; + let x = x.slice(s![b, .., .., ..]); + let out = unsafe_shared_out.get(); + let mut out = out.slice_mut(s![b, .., .., ..]); + let w = weights.view(); + + match (padded, strided, grouped) { + (true, true, true) => { + conv2d_launch::(x, w, &bias, &mut out, &options, ob) + } + (true, false, true) => { + conv2d_launch::(x, w, &bias, &mut out, &options, ob) + } + (false, true, true) => { + conv2d_launch::(x, w, &bias, &mut out, &options, ob) + } + (false, false, true) => { + conv2d_launch::(x, w, &bias, &mut out, &options, ob) + } + (true, true, false) => { + conv2d_launch::(x, w, &bias, &mut out, &options, ob) + } + (true, false, false) => { + conv2d_launch::(x, w, &bias, &mut out, &options, ob) + } + (false, true, false) => { + conv2d_launch::(x, w, &bias, &mut out, &options, ob) + } + (false, false, false) => { + conv2d_launch::(x, w, &bias, &mut out, &options, ob) + } + } + }); + }); + + let output = out.permuted_axes([0, 3, 1, 2]); + Ok(cast(output.into_dyn().into_shared())) +} + +/// Size of register blocks, we need to hardcode this because Rust and the `seq` macro don't support +/// using associated constants as constant parameters. 8 works for all semi-modern CPUs but might +/// not be perfectly optimized for AVX-512 capable CPUs (which probably should use 16). 
+/// This should always be conservative, since oversizing it will cause register spills and that's +/// **much** worse than the performance lost with lower values. +const REGISTER_BLOCK: usize = 8; +inner_with_register_blocking_size!(8); + +/// Run a loop of conv2d. +/// # SAFETY +/// See `conv2d_inner_nopad`, `conv2d_inner_nopad_nostride`, `conv2d_remainder`. +/// Required preconditions: `ob * simd_lanes` must be `<= out_channels - simd_lanes`, `weights` and +/// `out` must have unit stride for the out channels. +#[inline(always)] +#[macerator::with_simd] +unsafe fn conv2d_launch< + 'a, + S: Simd, + E: VMulAdd, + const PAD: bool, + const STRIDE: bool, + const GROUPS: bool, +>( + x: ArrayView3<'a, E>, + weights: ArrayView4<'a, E>, + bias: &'a Option>, + out: &'a mut ArrayViewMut3<'a, E>, + options: &'a ConvOptions<2>, + ob: usize, +) where + 'a: 'a, +{ + let (in_channels, k_height, k_width, out_channels) = weights.dim(); + let (out_height, out_width, _) = out.dim(); + let channels_per_group = out_channels / options.groups; + let lanes = E::lanes::(); + + let [mut pad_h, mut pad_w] = options.padding; + let [stride_h, stride_w] = options.stride; + let [dilate_h, dilate_w] = options.dilation; + + // Trick compiler into inlining 0 to padding + if !PAD { + pad_h = 0; + pad_w = 0; + } + + let oc_b = channels_per_group.min(lanes); + let ow_b = REGISTER_BLOCK; + + let ow_start = pad_w; + let ow_width = out_width.saturating_sub(2 * pad_w); + let oh_start = pad_h; + let oh_end = out_height.saturating_sub(pad_h); + + let ow_blocks = ow_width / ow_b; + let oc = ob * oc_b; + let group = oc / channels_per_group; + let mut ic_off = group * in_channels; + if !GROUPS { + ic_off = 0; + } + + unsafe { + let bias = if let Some(bias) = &bias { + vload_unaligned::(&bias[oc]) + } else { + Zeroable::zeroed() + }; + + for oh in oh_start..oh_end { + let mut out = out.slice_mut(s![oh, .., ..]); + for ow_block in 0..ow_blocks { + let ow = ow_block * ow_b + ow_start; + + #[allow(clippy::if_same_then_else)] + if STRIDE { + conv2d_inner_nopad( + &x, &weights, &mut out, bias, oh, ow, oc, ic_off, stride_h, stride_w, + dilate_h, dilate_w, k_height, k_width, pad_h, pad_w, + ); + } else { + conv2d_inner_nopad_nostride( + &x, &weights, &mut out, bias, oh, ow, oc, ic_off, k_height, k_width, pad_h, + pad_w, + ); + } + } + } + conv2d_remainder( + x, + weights, + out, + bias, + oc, + ic_off, + ow_blocks * ow_b, + stride_h, + stride_w, + dilate_h, + dilate_w, + pad_h, + pad_w, + k_height, + k_width, + ); + } +} + +/// Execute the non-unrolled and/or padded portion of the convolution. This has more checks and is +/// much slower, so we want to minimize the amount of pixels that need to be processed by this +/// +/// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector +/// is in bounds. Weights and `out` must be channels last (with `stride == 1`). 
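+//
+// With channels last and unit stride on the channel axis, a single unaligned
+// vector load or store covers channels `oc..oc + lanes`, which is why `oc` must
+// stay at least one vector away from the end. The pattern used throughout:
+//
+//     let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; // lanes weights
+//     acc = i0.mul_add(f0, acc);                                       // fused multiply-add
+//     unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) };        // lanes outputs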
+#[allow(clippy::too_many_arguments)] +#[inline(always)] +unsafe fn conv2d_remainder( + x: ArrayView3, + weights: ArrayView4, + out: &mut ArrayViewMut3, + bias: Vector, + oc: usize, + ic_off: usize, + owb_end: usize, + stride_h: usize, + stride_w: usize, + dilate_h: usize, + dilate_w: usize, + pad_h: usize, + pad_w: usize, + k_height: usize, + k_width: usize, +) { + let in_channels = weights.shape()[0]; + let (_, in_height, in_width) = x.dim(); + let (out_height, out_width, _) = out.dim(); + let oh_start = pad_h; + let oh_end = out_height.saturating_sub(pad_h); + let ow_start = pad_w; + + let height1 = in_height + pad_h; + let width1 = in_width + pad_w; + + for oh in (0..oh_start).chain(oh_end..out_height) { + for ow in 0..out_width { + let mut acc = bias; + + for ic in 0..in_channels { + for kh in 0..k_height { + let ih = oh * stride_h + kh * dilate_h; + if (ih < pad_h) | (ih >= height1) { + continue; + } + let ih = ih - pad_h; + + for kw in 0..k_width { + let iw = ow * stride_w + kw * dilate_w; + if (iw < pad_w) | (iw >= width1) { + continue; + } + let iw = iw - pad_w; + + // Load a full vector from the weights. This is guaranteed to be in bounds + // as long as `oc <= out_channels - simd_lanes` and out channels are last. + // We need to ensure the weights are reshaped appropriately. + let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; + + // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the + // compiler can't prove this. We can't use `as_slice` with fixed bounds + // because we want to support arbitrary input layouts. So an unchecked load + // is used. + let i0 = unsafe { x.uget([ic, ih, iw]) }.splat::(); + acc = i0.mul_add(f0, acc); + } + } + } + + // Store a full vector from the output. This is guaranteed to be in bounds + // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with + // channels last, so this always holds. + unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) }; + } + } + for ow in (0..ow_start).chain(owb_end..out_width) { + for oh in 0..out_height { + let mut acc = bias; + + for ic in 0..in_channels { + for kh in 0..k_height { + let ih = oh * stride_h + kh * dilate_h; + if (ih < pad_h) | (ih >= height1) { + continue; + } + let ih = ih - pad_h; + + for kw in 0..k_width { + let iw = ow * stride_w + kw * dilate_w; + if (iw < pad_w) | (iw >= width1) { + continue; + } + let iw = iw - pad_w; + + // Load a full vector from the weights. This is guaranteed to be in bounds + // as long as `oc <= out_channels - simd_lanes` and out channels are last. + // We need to ensure the weights are reshaped appropriately. + let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; + + // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the + // compiler can't prove this. We can't use `as_slice` with fixed bounds + // because we want to support arbitrary input layouts. So an unchecked load + // is used. + let i0 = unsafe { x.uget([ic_off + ic, ih, iw]) }.splat::(); + acc = i0.mul_add(f0, acc); + } + } + } + + // Store a full vector from the output. This is guaranteed to be in bounds + // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with + // channels last, so this always holds. + unsafe { vstore_unaligned(&mut out[[oh, ow, oc]], acc) }; + } + } +} + +macro_rules! inner_with_register_blocking_size { + ($rb: literal) => { + /// Execute the unrolled and unpadded portion of the convolution. 
Any pixel that is more than + /// `pad_h` away from the horizontal border, and `pad_w` away from the vertical border is + /// guaranteed to always be in bounds (because of the way out size is calculated). + /// + /// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector + /// is in bounds. Weights and `out` must be channels last (with `stride == 1`). + #[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)] + #[inline(always)] + unsafe fn conv2d_inner_nopad( + x: &ArrayView3, + weights: &ArrayView4, + out: &mut ArrayViewMut2, + bias: Vector, + oh: usize, + ow: usize, + oc: usize, + ic_off: usize, + stride_h: usize, + stride_w: usize, + dilate_h: usize, + dilate_w: usize, + k_height: usize, + k_width: usize, + pad_h: usize, + pad_w: usize, + ) { + let in_channels = weights.shape()[0]; + + seq!(N in 0..$rb { + let mut acc~N = bias; + }); + + for ic in 0..in_channels { + for kh in 0..k_height { + let ih = oh * stride_h + kh * dilate_h - pad_h; + + for kw in 0..k_width { + // Load a full vector from the weights. This is guaranteed to be in bounds + // as long as `oc <= out_channels - simd_lanes` and out channels are last. + // We need to ensure the weights are reshaped appropriately. + let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; + let iw = ow * stride_w + kw * dilate_w - pad_w; + + seq!(N in 0..$rb { + // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the + // compiler can't prove this. We can't use `as_slice` with fixed bounds + // because we want to support arbitrary input layouts. So an unchecked load + // is used. + let i~N = unsafe { x.uget([ic + ic_off, ih, iw + N * stride_w]) }.splat::(); + }); + seq!(N in 0..$rb { + acc~N = i~N.mul_add(f0, acc~N); + }); + } + } + } + + seq!(N in 0..$rb { + // Store a full vector from the output. This is guaranteed to be in bounds + // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with + // channels last, so this always holds. + unsafe { vstore_unaligned(&mut out[[ow + N, oc]], acc~N) }; + }); + } + + /// Execute the unrolled and unpadded portion of the convolution. Any pixel that is more than + /// `pad_h` away from the horizontal border, and `pad_w` away from the vertical border is + /// guaranteed to always be in bounds (because of the way out size is calculated). + /// + /// SAFETY: `oc` must be an index that's at most `out_channels - simd_lanes`, so the full vector + /// is in bounds. Weights and `out` must be channels last (with `stride == 1`). + #[allow(clippy::erasing_op, clippy::identity_op, clippy::too_many_arguments)] + #[inline(always)] + unsafe fn conv2d_inner_nopad_nostride( + x: &ArrayView3, + weights: &ArrayView4, + out: &mut ArrayViewMut2, + bias: Vector, + oh: usize, + ow: usize, + oc: usize, + ic_off: usize, + k_height: usize, + k_width: usize, + pad_h: usize, + pad_w: usize, + ) { + let in_channels = weights.shape()[0]; + + seq!(N in 0..$rb { + let mut acc~N = bias; + }); + + for ic in 0..in_channels { + for kh in 0..k_height { + let ih = oh + kh - pad_h; + + for kw in 0..k_width { + // Load a full vector from the weights. This is guaranteed to be in bounds + // as long as `oc <= out_channels - simd_lanes` and out channels are last. + // We need to ensure the weights are reshaped appropriately. 
+ let f0 = unsafe { vload_unaligned(&weights[[ic, kh, kw, oc]]) }; + let iw = ow + kw - pad_w; + + seq!(N in 0..$rb { + // The loop bounds ensure `ic`, `ih` and `iw` are always in bounds, but the + // compiler can't prove this. We can't use `as_slice` with fixed bounds + // because we want to support arbitrary input layouts. So an unchecked load + // is used. + let i~N = unsafe { x.uget([ic + ic_off, ih, iw + N]) }.splat::(); + }); + seq!(N in 0..$rb { + acc~N = i~N.mul_add(f0, acc~N); + }); + } + } + } + + seq!(N in 0..$rb { + // Store a full vector from the output. This is guaranteed to be in bounds + // as long as `oc <= out_channels - simd_lanes` and oc stride is 1. We create `out` with + // channels last, so this always holds. + unsafe { vstore_unaligned(&mut out[[ow + N, oc]], acc~N) }; + }); + } + }; +} +pub(crate) use inner_with_register_blocking_size; diff --git a/crates/burn/src/ops/simd/maxpool.rs b/crates/burn/src/ops/simd/maxpool.rs new file mode 100644 index 00000000..279af69b --- /dev/null +++ b/crates/burn/src/ops/simd/maxpool.rs @@ -0,0 +1,394 @@ +use core::{marker::PhantomData, mem::transmute}; + +use crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef}; + +use burn_backend::{BoolStore, DType, Element, quantization::QuantValue}; +use macerator::{Simd, VOrd}; +use ndarray::{Array4, s}; +use nhwc::max_pool2d_nhwc; + +use super::{MinMax, should_use_simd}; + +#[macerator::with_simd] +fn is_accelerated_impl(_x: PhantomData) -> bool { + ::is_min_max_accelerated::() +} + +fn is_accelerated() -> bool { + is_accelerated_impl::(PhantomData) +} + +macro_rules! launch_kernel { + ($ty: ty, $func: ident, $x: expr, $($arg: expr),*) => { + match <$ty as Element>::dtype() { + DType::F64 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::F32 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::I64 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::I32 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::I16 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::I8 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::U64 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::U32 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::U16 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::U8 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::Bool(BoolStore::Native) if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::QFloat(scheme) => match scheme.value { + QuantValue::Q8F | QuantValue::Q8S if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + _ => Err($x) + }, + _ => Err($x), + } + }; +} + +pub(crate) fn try_max_pool2d_simd( + x: SharedArray, + ksize: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], +) -> Result, SharedArray> { + let [_, c, _, _] = x.shape().try_into().unwrap(); + if !should_use_simd(c) || x.strides()[1] != 1 { + return Err(x); + } + + launch_kernel!(E, max_pool2d_nhwc, x, ksize, stride, padding, dilation) +} + +fn cast(tensor: SharedArray) -> SharedArray { + unsafe { transmute::, SharedArray>(tensor) } +} + +mod nhwc { + use itertools::Itertools; + use macerator::{Simd, vload_unaligned, vstore_unaligned}; + use ndarray::{ArrayView3, ArrayViewMut3, Ix4}; + use seq_macro::seq; + + use crate::ops::simd::lanes; + + use super::*; + + // Until you can use associated constants as array 
size, we need to hardcode this. + // The most common config (x86-v3) has 16 registers, so use half of them for accumulators. + const BLOCK_REGISTERS: usize = 8; + + pub(crate) fn max_pool2d_nhwc( + x: SharedArray, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ) -> SharedArray { + let [kernel_height, kernel_width] = kernel_size; + let [pad_h, pad_w] = padding; + let [stride_height, stride_width] = stride; + let [dilation_height, dilation_width] = dilation; + let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap(); + let lanes = lanes::(); + + let ch_block = lanes * BLOCK_REGISTERS; + + let out_height = ((x_height + 2 * pad_h - dilation_height * (kernel_height - 1) - 1) + / stride_height) + + 1; + let out_width = + ((x_width + 2 * pad_w - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1; + + let mut output = unsafe { + Array4::::uninit((batch_size, out_height, out_width, channels)).assume_init() + }; + let unsafe_shared_out = UnsafeSharedRef::new(&mut output); + + let x = x.into_dimensionality::().unwrap(); + let x = x.view(); + let x = x.permuted_axes([0, 2, 3, 1]); + + // Floor division ensures `blocks * lanes * blocking factor` is always `<= out_channels`. + // An exclusive loop will always have `lanes * blocking factor` elements in bounds. + let blocks = channels / ch_block; + let blocks_end = blocks * ch_block; + // Floor division means simd_end is always divisible by `lanes` and `<= out_channels`. An + // exclusive loop will always have `lanes` elements in bounds. + let simd_end = channels / lanes * lanes; + let simd_unblocked = (simd_end - blocks_end) / lanes; + let remainder = channels - simd_end; + + run_par!(|| { + // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. + iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe { + let block = k % blocks; + let b = k / blocks; + + let output = unsafe_shared_out.get(); + let x = x.slice(s![b, .., .., ..]); + let out = output.slice_mut(s![b, .., .., ..]); + loop_blocked(x, out, kernel_size, stride, padding, dilation, block); + }); + // SAFETY: See `loop_unblocked` + iter_range_par!(0, batch_size * simd_unblocked).for_each(|k| unsafe { + let ch = (k % simd_unblocked) * lanes + blocks_end; + let b = k / simd_unblocked; + + let output = unsafe_shared_out.get(); + let x = x.slice(s![b, .., .., ..]); + let out = output.slice_mut(s![b, .., .., ..]); + loop_unblocked(x, out, kernel_size, stride, padding, dilation, ch); + }); + // SAFETY: Loop ranges are non-overlapping, so the unsafe shared reference is safe. + iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe { + let ch = (k % remainder) + simd_end; + let b = k / remainder; + + let output = unsafe_shared_out.get(); + let x = x.slice(s![b, .., .., ..]); + let out = output.slice_mut(s![b, .., .., ..]); + loop_scalar(x, out, kernel_size, stride, padding, dilation, ch); + }); + }); + + output = output.permuted_axes([0, 3, 1, 2]); + + output.into_dyn().into_shared() + } + + /// Execute the blocked (unrolled) portion of the pool. 
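+    // Each output pixel keeps `BLOCK_REGISTERS` (= 8) vector accumulators live, so
+    // one pass covers `8 * lanes` channels; this is what the `seq!(N in 0..8)`
+    // blocks below unroll. Schematically:
+    //
+    //     acc[0..8] = splat(E::MIN);
+    //     for (kh, kw) in kernel_window {
+    //         for n in 0..8 { acc[n] = acc[n].max(load(x[ih, iw, ch + n * lanes])); }
+    //     }
+    //     for n in 0..8 { store(out[oh, ow, ch + n * lanes], acc[n]); }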
+ #[allow( + clippy::too_many_arguments, + clippy::erasing_op, + clippy::identity_op, + unused_mut + )] + #[inline(always)] + #[macerator::with_simd] + fn loop_blocked<'a, S: Simd, E: Element + VOrd + MinMax>( + x: ArrayView3<'a, E>, + mut out: ArrayViewMut3<'a, E>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + block: usize, + ) where + 'a: 'a, + { + let [kernel_height, kernel_width] = kernel_size; + let [pad_h, pad_w] = padding; + let [stride_height, stride_width] = stride; + let [dilation_height, dilation_width] = dilation; + + let (x_height, x_width, _) = x.dim(); + let (out_height, out_width, _) = out.dim(); + let lanes = E::lanes::(); + let ch_block = lanes * BLOCK_REGISTERS; + + let min = E::MIN.splat::(); + // If outside padding area, kernels are guaranteed to be in bounds + for oh in pad_h..out_height.saturating_sub(pad_h) { + for ow in pad_w..out_width.saturating_sub(pad_w) { + seq!(N in 0..8 { + let mut acc~N = min; + }); + let ch = block * ch_block; + let ch_end = ch + ch_block; + let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh * dilation_height - pad_h; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw * dilation_width - pad_w; + let x = x.slice(s![ih, iw, ch..ch_end]); + + seq!(N in 0..8 { + // SAFETY: + // Load a full vector from x[N * lanes]. This is bounds checked by the + // slice above. + acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) }); + }); + } + } + + seq!(N in 0..8 { + // SAFETY: + // Store a full vector to out[N * lanes]. This is bounds checked by the + // slice above. + unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) }; + }); + } + } + + // Border pixels need bounds checks + if (pad_h, pad_w) != (0, 0) { + let v_borders = (0..pad_h) + .chain(out_height.saturating_sub(pad_h)..out_height) + .cartesian_product(0..out_width); + let h_borders = (0..out_height) + .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); + + for (oh, ow) in v_borders.chain(h_borders) { + seq!(N in 0..8 { + let mut acc~N = min; + }); + let ch = block * ch_block; + let ch_end = ch + ch_block; + let mut out = out.slice_mut(s![oh, ow, ch..ch_end]); + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh * dilation_height; + if ih < pad_h || ih >= x_height + pad_h { + continue; + } + let ih = ih - pad_h; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw * dilation_width; + if iw < pad_w || iw >= x_width + pad_w { + continue; + } + let iw = iw - pad_w; + + let x = x.slice(s![ih, iw, ch..ch_end]); + + seq!(N in 0..8 { + // SAFETY: + // Load a full vector from x[N * lanes]. This is bounds checked by the + // slice above. + acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) }); + }); + } + } + + seq!(N in 0..8 { + // SAFETY: + // Store a full vector to out[N * lanes]. This is bounds checked by the + // slice above. + unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) }; + }); + } + } + } + + /// Execute the unblocked (not unrolled) portion of the pool. + /// + /// SAFETY: Safe as long as `ch + simd_lanes <= out_channels`. 
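+    // The caller (`max_pool2d_nhwc` above) upholds this by deriving `ch` from
+    // `blocks_end..simd_end` in steps of `lanes`, with
+    // `simd_end = channels / lanes * lanes`, so a full vector is always in bounds:
+    //
+    //     let ch = (k % simd_unblocked) * lanes + blocks_end; // ch + lanes <= simd_end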
+ #[allow(clippy::too_many_arguments, unused_mut)] + #[inline(always)] + #[macerator::with_simd] + unsafe fn loop_unblocked<'a, S: Simd, E: Element + VOrd + MinMax>( + x: ArrayView3<'a, E>, + mut out: ArrayViewMut3<'a, E>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ch: usize, + ) where + 'a: 'a, + { + let [kernel_height, kernel_width] = kernel_size; + let [pad_h, pad_w] = padding; + let [stride_height, stride_width] = stride; + let [dilation_height, dilation_width] = dilation; + + let (x_height, x_width, _) = x.dim(); + let (out_height, out_width, _) = out.dim(); + + for oh in pad_h..out_height.saturating_sub(pad_h) { + for ow in pad_w..out_width.saturating_sub(pad_w) { + let mut acc = E::MIN.splat::(); + let out = &mut out[[oh, ow, ch]]; + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh * dilation_height - pad_h; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw * dilation_width - pad_w; + // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` + acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) }); + } + } + // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. + unsafe { vstore_unaligned(out, acc) }; + } + } + + // Border pixels need bounds checks + if (pad_h, pad_w) != (0, 0) { + let v_borders = (0..pad_h) + .chain(out_height.saturating_sub(pad_h)..out_height) + .cartesian_product(0..out_width); + let h_borders = (0..out_height) + .cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width)); + + for (oh, ow) in v_borders.chain(h_borders) { + let mut acc = E::MIN.splat::(); + let out = &mut out[[oh, ow, ch]]; + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh * dilation_height; + if ih < pad_h || ih >= x_height + pad_h { + continue; + } + let ih = ih - pad_h; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw * dilation_width; + if iw < pad_w || iw >= x_width + pad_w { + continue; + } + let iw = iw - pad_w; + // Load a full vector from `x`. In bounds as long as `out_channels >= ch + lanes` + acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) }); + } + } + // Store a full vector to `out`. In bounds as long as `out_channels >= ch + lanes`. 
+ unsafe { vstore_unaligned(out, acc) }; + } + } + } + + fn loop_scalar( + x: ArrayView3<'_, E>, + mut out: ArrayViewMut3<'_, E>, + kernel_size: [usize; 2], + stride: [usize; 2], + padding: [usize; 2], + dilation: [usize; 2], + ch: usize, + ) { + let [kernel_height, kernel_width] = kernel_size; + let [pad_h, pad_w] = padding; + let [stride_height, stride_width] = stride; + let [dilation_height, dilation_width] = dilation; + + let (x_height, x_width, _) = x.dim(); + let (out_height, out_width, _) = out.dim(); + + for oh in 0..out_height { + for ow in 0..out_width { + let mut acc = E::MIN; + + for kh in 0..kernel_height { + let ih = oh * stride_height + kh * dilation_height; + if ih < pad_h || ih >= x_height + pad_h { + continue; + } + let ih = ih - pad_h; + + for kw in 0..kernel_width { + let iw = ow * stride_width + kw * dilation_width; + if iw < pad_w || iw >= x_width + pad_w { + continue; + } + let iw = iw - pad_w; + acc = acc.max(x[[ih, iw, ch]]); + } + } + + out[[oh, ow, ch]] = acc; + } + } + } +} diff --git a/crates/burn/src/ops/simd/mod.rs b/crates/burn/src/ops/simd/mod.rs new file mode 100644 index 00000000..2032f30c --- /dev/null +++ b/crates/burn/src/ops/simd/mod.rs @@ -0,0 +1,10 @@ +pub(crate) mod avgpool; +mod base; +pub(crate) mod binary; +pub(crate) mod binary_elemwise; +pub(crate) mod cmp; +pub(crate) mod conv; +pub(crate) mod maxpool; +pub(crate) mod unary; + +pub use base::*; diff --git a/crates/burn/src/ops/simd/unary.rs b/crates/burn/src/ops/simd/unary.rs new file mode 100644 index 00000000..68d26267 --- /dev/null +++ b/crates/burn/src/ops/simd/unary.rs @@ -0,0 +1,234 @@ +use core::marker::PhantomData; + +use bytemuck::cast; +use macerator::{ + Scalar, Simd, VAbs, VBitNot, VRecip, Vector, vload, vload_unaligned, vstore, vstore_unaligned, +}; +use ndarray::ArrayD; +use num_traits::Signed; +use seq_macro::seq; + +use crate::{NdArrayElement, SharedArray}; + +use super::should_use_simd; + +pub trait SimdUnop { + fn apply_vec(input: Vector) -> Vector; + fn apply(input: T) -> Out; + fn is_accelerated() -> bool; +} + +pub struct RecipVec; + +impl SimdUnop for RecipVec { + fn apply_vec(input: Vector) -> Vector { + input.recip() + } + + fn apply(input: f32) -> f32 { + input.recip() + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +pub struct VecAbs; + +impl SimdUnop for VecAbs { + fn apply_vec(input: Vector) -> Vector { + input.abs() + } + + fn apply(input: T) -> T { + input.abs() + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +pub struct VecBitNot; + +impl SimdUnop for VecBitNot { + fn apply_vec(input: Vector) -> Vector { + !input + } + + fn apply(input: T) -> T { + input.not() + } + + fn is_accelerated() -> bool { + ::is_accelerated::() + } +} + +#[macerator::with_simd] +fn is_accelerated>( + _x: PhantomData<(T, Out, Op)>, +) -> bool { + Op::is_accelerated::() +} + +pub fn try_unary_simd< + E: NdArrayElement, + EOut: NdArrayElement, + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: SimdUnop, +>( + input: SharedArray, +) -> Result, SharedArray> { + if !should_use_simd(input.len()) + || input.as_slice_memory_order().is_none() + || !is_accelerated::(PhantomData) + { + return Err(input); + } + // Used to assert traits based on the dynamic `DType`. 
+ let input = unsafe { core::mem::transmute::, SharedArray>(input) }; + let out = if size_of::() == size_of::() + && align_of::() >= align_of::() + && input.is_unique() + { + unsafe { unary_scalar_simd_inplace::(input) } + } else { + unary_scalar_simd_owned::(input) + }; + // Used to assert traits based on the dynamic `DType`. + let out = unsafe { core::mem::transmute::, SharedArray>(out) }; + Ok(out) +} + +/// Execute operation in line. +/// SAFETY: +/// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. +unsafe fn unary_scalar_simd_inplace< + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: SimdUnop, +>( + input: SharedArray, +) -> SharedArray { + let mut buffer = input.into_owned(); + let slice = buffer.as_slice_memory_order_mut().unwrap(); + // This is only called when in and out have the same size, so it's safe + unsafe { unary_slice_inplace::(slice, PhantomData) }; + // Buffer has the same elem size and is filled with the operation output, so this is safe + let out = unsafe { core::mem::transmute::, ArrayD>(buffer) }; + out.into_shared() +} + +fn unary_scalar_simd_owned< + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: SimdUnop, +>( + input: SharedArray, +) -> SharedArray { + let mut out = unsafe { ArrayD::uninit(input.shape()).assume_init() }; + let input = input.as_slice_memory_order().unwrap(); + let out_slice = out.as_slice_memory_order_mut().unwrap(); + unary_slice::(input, out_slice, PhantomData); + out.into_shared() +} + +#[allow(clippy::erasing_op, clippy::identity_op)] +#[macerator::with_simd] +fn unary_slice< + 'a, + S: Simd, + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: SimdUnop, +>( + input: &'a [T], + out: &'a mut [Out], + _op: PhantomData, +) where + 'a: 'a, +{ + let lanes = T::lanes::(); + let mut chunks_input = input.chunks_exact(8 * lanes); + let mut chunks_out = out.chunks_exact_mut(8 * lanes); + while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { + seq!(N in 0..8 { + // Load one full vector from `input`. + // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + let s~N = unsafe { vload_unaligned(&input[N * lanes]) }; + let s~N = Op::apply_vec::(s~N); + // Store one full vector to `out`. + // SAFETY: Guaranteed to be in bounds because `len == 8 * lanes` + unsafe { vstore_unaligned(&mut out[N * lanes], s~N) }; + }); + } + let mut chunks_input = chunks_input.remainder().chunks_exact(lanes); + let mut chunks_out = chunks_out.into_remainder().chunks_exact_mut(lanes); + while let Some((input, out)) = chunks_input.next().zip(chunks_out.next()) { + // Load one full vector from `input`. + // SAFETY: Guaranteed to be in bounds because `len == lanes` + let s0 = unsafe { vload_unaligned(input.as_ptr()) }; + let s0 = Op::apply_vec::(s0); + // Store one full vector to `out`. + // SAFETY: Guaranteed to be in bounds because `len == lanes` + unsafe { vstore_unaligned(out.as_mut_ptr(), s0) }; + } + + for (input, out) in chunks_input + .remainder() + .iter() + .zip(chunks_out.into_remainder()) + { + *out = Op::apply(*input) + } +} + +/// Execute operation in line. +/// SAFETY: +/// Must ensure `size_of:: == size_of::` and `align_of:: >= align_of::`. 
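+//
+// The in-place path first splits the buffer with `align_to_mut` into an unaligned
+// head, a maximal aligned run of whole vectors, and an unaligned tail; head and
+// tail fall back to scalar application and only the middle uses aligned vector
+// loads and stores:
+//
+//     let (head, main, tail) = unsafe { buf.align_to_mut::<Vector<S, T>>() };
+//     for elem in head.iter_mut().chain(tail) {
+//         *elem = cast(Op::apply(*elem));   // scalar fallback
+//     }
+//     // `main` is then processed 8 vectors per iteration with `vload` / `vstore`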
+#[macerator::with_simd] +unsafe fn unary_slice_inplace< + 'a, + S: Simd, + T: NdArrayElement + Scalar, + Out: NdArrayElement + Scalar, + Op: SimdUnop, +>( + buf: &'a mut [T], + _op: PhantomData<(Out, Op)>, +) where + 'a: 'a, +{ + let (head, main, tail) = unsafe { buf.align_to_mut::>() }; + for elem in head.iter_mut().chain(tail) { + *elem = cast(Op::apply(*elem)); + } + let mut chunks = main.chunks_exact_mut(8); + for elem in chunks.by_ref() { + seq!(N in 0..8 { + // Load a full vector from the aligned portion of the buffer. + // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is + // always a full vector in bounds. + let s~N = unsafe { vload(&elem[N] as *const _ as *const T) }; + let s~N = Op::apply_vec::(s~N); + // Store a full vector at the same position as the input. Cast is safe because `Out` is + // size and align compatible + unsafe { vstore(&mut elem[N] as *mut _ as *mut Out, s~N) }; + }); + } + + for elem in chunks.into_remainder() { + // Load a full vector from the aligned portion of the buffer. + // SAFETY: `align_to_mut` guarantees we're aligned to `T::Vector`'s size, and there is + // always a full vector in bounds. + let s0 = unsafe { vload(elem as *const _ as *const T) }; + + let s0 = Op::apply_vec::(s0); + // Store a full vector at the same position as the input. Cast is safe because `Out` is + // size and align compatible + unsafe { vstore(elem as *mut _ as *mut Out, s0) }; + } +} diff --git a/crates/burn/src/ops/tensor.rs b/crates/burn/src/ops/tensor.rs new file mode 100644 index 00000000..c5f26f38 --- /dev/null +++ b/crates/burn/src/ops/tensor.rs @@ -0,0 +1,801 @@ +// Language +use alloc::vec::Vec; +use burn_backend::backend::ExecutionError; +use burn_backend::ops::GridSampleOptions; +use burn_backend::tensor::FloatTensor; +use burn_backend::{TensorMetadata, element::cast::ToElement}; +use burn_std::{BoolDType, IntDType}; + +// Current crate +use super::{ + NdArrayMathOps, NdArrayOps, + matmul::{cross, matmul}, +}; +use crate::{ + NdArray, cast_to_dtype, cat_with_dtype, execute_with_int_dtype, tensor::NdArrayTensor, +}; +use crate::{NdArrayDevice, SEED, execute_with_float_out_dtype, execute_with_int_out_dtype, slice}; +use crate::{ + SharedArray, + element::{ExpElement, FloatNdArrayElement, IntNdArrayElement, QuantElement}, +}; +use crate::{execute_with_float_dtype, ops::grid_sample::grid_sample_2d}; + +// Workspace crates +use crate::rand::get_seeded_rng; +use burn_backend::{Distribution, FloatDType, Scalar}; +use burn_backend::{ElementConversion, Shape, TensorData, backend::Backend, ops::FloatTensorOps}; + +#[cfg(not(feature = "std"))] +#[allow(unused_imports)] +use num_traits::Float; + +use libm::erf; + +/// Try to accelerate a unary f32 operation via ndarray's hpc::vml (F32x16 SIMD). +/// +/// VML signature: `fn(input: &[f32], output: &mut [f32])`. +/// Uses crate::simd::F32x16 internally. Consumer never sees hardware details. 
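+//
+// A sketch of the intended call pattern: try the VML fast path first, then fall
+// back to the generic element-wise path on `Err` (`vml_exp` and `fallback` are
+// hypothetical stand-ins; only the `Result`-based fallback shape comes from this
+// patch):
+//
+//     match try_vml_unary(tensor, vml_exp) {
+//         Ok(out) => out,                  // contiguous f32: accelerated path
+//         Err(tensor) => fallback(tensor), // other dtype or layout: generic path
+//     }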
+#[cfg(feature = "simd")] +fn try_vml_unary( + tensor: NdArrayTensor, + vml_fn: fn(&[f32], &mut [f32]), +) -> Result { + if let NdArrayTensor::F32(storage) = tensor { + let shared = storage.into_shared(); + if shared.is_standard_layout() { + if let Some(input) = shared.as_slice() { + let mut output = vec![0.0f32; input.len()]; + vml_fn(input, &mut output); + let shape = shared.shape().to_vec(); + let array = ndarray::Array::from_shape_vec(ndarray::IxDyn(&shape), output) + .expect("vml output shape mismatch"); + return Ok(NdArrayTensor::F32( + crate::NdArrayStorage::Owned(array.into_shared()), + )); + } + } + return Err(NdArrayTensor::F32(crate::NdArrayStorage::Owned(shared))); + } + Err(tensor) +} + +#[cfg(feature = "std")] +#[allow(dead_code)] +fn round_ties_even_wrapper(x: f64) -> f64 { + x.round_ties_even() +} + +#[cfg(not(feature = "std"))] +#[allow(dead_code)] +fn round_ties_even_wrapper(x: f64) -> f64 { + if (x - x.floor()) == 0.5 { + (x * 0.5).round() * 2.0 + } else { + x.round() + } +} + +impl FloatTensorOps + for NdArray +where + NdArrayTensor: From>, + NdArrayTensor: From>, +{ + fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor { + NdArrayTensor::from_data(data) + } + + fn float_random( + shape: Shape, + distribution: Distribution, + device: &NdArrayDevice, + dtype: FloatDType, + ) -> FloatTensor { + let mut seed = SEED.lock().unwrap(); + let mut rng = seed.take().unwrap_or_else(get_seeded_rng); + let tensor = execute_with_float_out_dtype!( + dtype, + E, + Self::float_from_data( + TensorData::random::(shape, distribution, &mut rng), + device, + ) + ); + + *seed = Some(rng); + tensor + } + + async fn float_into_data(tensor: FloatTensor) -> Result { + Ok(tensor.into_data()) + } + + fn float_device(_tensor: &FloatTensor) -> NdArrayDevice { + NdArrayDevice::Cpu + } + + fn float_to_device(tensor: FloatTensor, _device: &NdArrayDevice) -> FloatTensor { + tensor + } + + fn float_empty( + shape: Shape, + device: & as Backend>::Device, + dtype: FloatDType, + ) -> FloatTensor { + Self::float_zeros(shape, device, dtype) + } + + fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { + execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add) + } + + fn float_add_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { + execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { + NdArrayMathOps::add_scalar(array, rhs.elem()) + }) + } + + fn float_sub(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { + execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub) + } + + fn float_sub_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { + execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { + NdArrayMathOps::sub_scalar(array, rhs.elem()) + }) + } + + fn float_mul(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { + execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul) + } + + fn float_mul_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { + execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { + NdArrayMathOps::mul_scalar(array, rhs.elem()) + }) + } + + fn float_div(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { + execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div) + } + + fn float_div_scalar(lhs: FloatTensor, rhs: Scalar) -> FloatTensor { + execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray| { + NdArrayMathOps::div_scalar(array, rhs.elem()) + }) + } + + fn float_remainder(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { + execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder) + 
+impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
+    for NdArray<E, I, Q>
+where
+    NdArrayTensor: From<SharedArray<E>>,
+    NdArrayTensor: From<SharedArray<I>>,
+{
+    fn float_from_data(data: TensorData, _device: &NdArrayDevice) -> FloatTensor<Self> {
+        NdArrayTensor::from_data(data)
+    }
+
+    fn float_random(
+        shape: Shape,
+        distribution: Distribution,
+        device: &NdArrayDevice,
+        dtype: FloatDType,
+    ) -> FloatTensor<Self> {
+        let mut seed = SEED.lock().unwrap();
+        let mut rng = seed.take().unwrap_or_else(get_seeded_rng);
+        let tensor = execute_with_float_out_dtype!(
+            dtype,
+            E,
+            Self::float_from_data(
+                TensorData::random::<E, _, _>(shape, distribution, &mut rng),
+                device,
+            )
+        );
+
+        *seed = Some(rng);
+        tensor
+    }
+
+    async fn float_into_data(tensor: FloatTensor<Self>) -> Result<TensorData, DataError> {
+        Ok(tensor.into_data())
+    }
+
+    fn float_device(_tensor: &FloatTensor<Self>) -> NdArrayDevice {
+        NdArrayDevice::Cpu
+    }
+
+    fn float_to_device(tensor: FloatTensor<Self>, _device: &NdArrayDevice) -> FloatTensor<Self> {
+        tensor
+    }
+
+    fn float_empty(
+        shape: Shape,
+        device: &<NdArray<E, I, Q> as Backend>::Device,
+        dtype: FloatDType,
+    ) -> FloatTensor<Self> {
+        Self::float_zeros(shape, device, dtype)
+    }
+
+    fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::add)
+    }
+
+    fn float_add_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
+        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::add_scalar(array, rhs.elem())
+        })
+    }
+
+    fn float_sub(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::sub)
+    }
+
+    fn float_sub_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
+        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::sub_scalar(array, rhs.elem())
+        })
+    }
+
+    fn float_mul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::mul)
+    }
+
+    fn float_mul_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
+        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::mul_scalar(array, rhs.elem())
+        })
+    }
+
+    fn float_div(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::div)
+    }
+
+    fn float_div_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
+        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::div_scalar(array, rhs.elem())
+        })
+    }
+
+    fn float_remainder(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!((lhs, rhs), NdArrayMathOps::remainder)
+    }
+
+    fn float_remainder_scalar(lhs: FloatTensor<Self>, rhs: Scalar) -> FloatTensor<Self> {
+        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::remainder_scalar(array, rhs.elem())
+        })
+    }
+
+    fn float_matmul(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!((lhs, rhs), matmul)
+    }
+
+    fn float_cross(
+        lhs: FloatTensor<Self>,
+        rhs: FloatTensor<Self>,
+        dim: usize,
+    ) -> FloatTensor<Self> {
+        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| cross(lhs, rhs, dim))
+    }
+
+    fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::recip(array)
+        })
+    }
+
+    fn float_swap_dims(tensor: FloatTensor<Self>, dim1: usize, dim2: usize) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayOps::swap_dims(array, dim1, dim2)
+        })
+    }
+
+    fn float_reshape(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayOps::reshape(array, shape)
+        })
+    }
+
+    fn float_gather(
+        dim: usize,
+        tensor: FloatTensor<Self>,
+        indices: NdArrayTensor,
+    ) -> FloatTensor<Self> {
+        execute_with_int_dtype!(
+            indices,
+            IntElem,
+            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
+                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+                    NdArrayOps::gather(dim, array, idx_array)
+                })
+            }
+        )
+    }
+
+    fn float_scatter_add(
+        dim: usize,
+        tensor: FloatTensor<Self>,
+        indices: NdArrayTensor,
+        value: FloatTensor<Self>,
+    ) -> FloatTensor<Self> {
+        execute_with_int_dtype!(
+            indices,
+            IntElem,
+            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
+                execute_with_float_dtype!((tensor, value), |tensor, value| NdArrayOps::scatter(
+                    dim, tensor, idx_array, value
+                ))
+            }
+        )
+    }
+
+    fn float_select(
+        tensor: FloatTensor<Self>,
+        dim: usize,
+        indices: NdArrayTensor,
+    ) -> FloatTensor<Self> {
+        execute_with_int_dtype!(
+            indices,
+            IntElem,
+            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
+                execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+                    NdArrayMathOps::select(array, dim, idx_array)
+                })
+            }
+        )
+    }
+
+    fn float_select_add(
+        tensor: FloatTensor<Self>,
+        dim: usize,
+        indices: NdArrayTensor,
+        value: FloatTensor<Self>,
+    ) -> FloatTensor<Self> {
+        execute_with_int_dtype!(
+            indices,
+            IntElem,
+            |idx_array: SharedArray<IntElem>| -> NdArrayTensor {
+                execute_with_float_dtype!((tensor, value), |tensor, value| {
+                    NdArrayMathOps::select_assign(tensor, dim, idx_array, value)
+                })
+            }
+        )
+    }
+
+    fn float_slice(tensor: FloatTensor<Self>, slices: &[burn_backend::Slice]) -> FloatTensor<Self> {
+        slice!(tensor, slices)
+    }
+
+    fn float_slice_assign(
+        tensor: FloatTensor<Self>,
+        slices: &[burn_backend::Slice],
+        value: FloatTensor<Self>,
+    ) -> FloatTensor<Self> {
+        execute_with_float_dtype!((tensor, value), |tensor, value| {
+            NdArrayOps::slice_assign(tensor, slices, value)
+        })
+    }
+
+    fn float_mask_where(
+        tensor: FloatTensor<Self>,
+        mask: NdArrayTensor,
+        value: FloatTensor<Self>,
+    ) -> FloatTensor<Self> {
+        execute_with_float_dtype!((tensor, value), |tensor, value| {
+            NdArrayOps::mask_where(tensor, mask.bool(), value)
+        })
+    }
+
+    fn float_mask_fill(
+        tensor: FloatTensor<Self>,
+        mask: NdArrayTensor,
+        value: Scalar,
+    ) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayOps::mask_fill(array, mask.bool(), value.elem())
+        })
+    }
+
+    fn float_equal(
+        lhs: FloatTensor<Self>,
+        rhs: FloatTensor<Self>,
+        _out_dtype: BoolDType,
+    ) -> NdArrayTensor {
+        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::equal(lhs, rhs) })
+    }
+
+    fn float_equal_elem(
+        lhs: FloatTensor<Self>,
+        rhs: Scalar,
+        _out_dtype: BoolDType,
+    ) -> NdArrayTensor {
+        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::equal_elem(array, rhs.elem())
+        })
+    }
+
+    fn float_greater(
+        lhs: FloatTensor<Self>,
+        rhs: FloatTensor<Self>,
+        _out_dtype: BoolDType,
+    ) -> NdArrayTensor {
+        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::greater(lhs, rhs) })
+    }
+
+    fn float_greater_elem(
+        lhs: FloatTensor<Self>,
+        rhs: Scalar,
+        _out_dtype: BoolDType,
+    ) -> NdArrayTensor {
+        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::greater_elem(array, rhs.elem())
+        })
+    }
+
+    fn float_greater_equal(
+        lhs: FloatTensor<Self>,
+        rhs: FloatTensor<Self>,
+        _out_dtype: BoolDType,
+    ) -> NdArrayTensor {
+        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
+            NdArrayMathOps::greater_equal(lhs, rhs)
+        })
+    }
+
+    fn float_greater_equal_elem(
+        lhs: FloatTensor<Self>,
+        rhs: Scalar,
+        _out_dtype: BoolDType,
+    ) -> NdArrayTensor {
+        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::greater_equal_elem(array, rhs.elem())
+        })
+    }
+
+    fn float_lower(
+        lhs: FloatTensor<Self>,
+        rhs: FloatTensor<Self>,
+        _out_dtype: BoolDType,
+    ) -> NdArrayTensor {
+        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| { NdArrayMathOps::lower(lhs, rhs) })
+    }
+
+    fn float_lower_elem(
+        lhs: FloatTensor<Self>,
+        rhs: Scalar,
+        _out_dtype: BoolDType,
+    ) -> NdArrayTensor {
+        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::lower_elem(array, rhs.elem())
+        })
+    }
+
+    fn float_lower_equal(
+        lhs: FloatTensor<Self>,
+        rhs: FloatTensor<Self>,
+        _out_dtype: BoolDType,
+    ) -> NdArrayTensor {
+        execute_with_float_dtype!((lhs, rhs), |lhs, rhs| {
+            NdArrayMathOps::lower_equal(lhs, rhs)
+        })
+    }
+
+    fn float_lower_equal_elem(
+        lhs: FloatTensor<Self>,
+        rhs: Scalar,
+        _out_dtype: BoolDType,
+    ) -> NdArrayTensor {
+        execute_with_float_dtype!(lhs, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::lower_equal_elem(array, rhs.elem())
+        })
+    }
+
+    fn float_detach(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        tensor
+    }
+
+    fn float_mean(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        // Use view() for zero-copy on borrowed storage
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::mean_view(array.view())
+        })
+    }
+
+    fn float_sum(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        // Use view() for zero-copy on borrowed storage
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::sum_view(array.view())
+        })
+    }
+
+    fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::mean_dim(array, dim)
+        })
+    }
+
+    fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::cumsum(array, dim)
+        })
+    }
+
+    fn float_cumprod(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::cumprod(array, dim)
+        })
+    }
+
+    fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::cummin(array, dim)
+        })
+    }
+
+    fn float_cummax(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::cummax(array, dim)
+        })
+    }
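The cumulative ops above reduce along one axis while keeping the tensor's shape. For intuition, cumsum folds each lane left to right; a minimal sketch of that semantics using ndarray's `accumulate_axis_inplace` (an assumption about how `NdArrayMathOps::cumsum` could be realized, not a claim about its actual body):

    use ndarray::{array, Axis};

    let mut a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
    // Fold each row (dim = 1): every element accumulates its predecessor.
    a.accumulate_axis_inplace(Axis(1), |&prev, cur| *cur += prev);
    assert_eq!(a, array![[1.0, 3.0, 6.0], [4.0, 9.0, 15.0]]);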
+    fn float_sum_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::sum_dim(array, dim)
+        })
+    }
+
+    fn float_argmax(
+        tensor: FloatTensor<Self>,
+        dim: usize,
+        out_dtype: IntDType,
+    ) -> NdArrayTensor {
+        // Use view() for zero-copy on borrowed storage
+        execute_with_int_out_dtype!(out_dtype, I, {
+            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+                NdArrayMathOps::argmax_view::<I>(array.view(), dim)
+            })
+        })
+    }
+
+    fn float_argmin(
+        tensor: FloatTensor<Self>,
+        dim: usize,
+        out_dtype: IntDType,
+    ) -> NdArrayTensor {
+        // Use view() for zero-copy on borrowed storage
+        execute_with_int_out_dtype!(out_dtype, I, {
+            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+                NdArrayMathOps::argmin_view::<I>(array.view(), dim)
+            })
+        })
+    }
+
+    fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        // Fast path: contiguous f32 → ndarray::hpc::vml::vsexp (F32x16 SIMD).
+        // Falls through to scalar mapv_into for non-f32 or non-contiguous.
+        #[cfg(feature = "simd")]
+        let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vsexp) {
+            Ok(result) => return result,
+            Err(t) => t,
+        };
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array.mapv_into(|a: FloatElem| a.exp_elem()).into_shared()
+        })
+    }
+
+    fn float_log(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        #[cfg(feature = "simd")]
+        let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vsln) {
+            Ok(result) => return result,
+            Err(t) => t,
+        };
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array.mapv_into(|a: FloatElem| a.log_elem()).into_shared()
+        })
+    }
+
+    fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        // Use view() for zero-copy on borrowed storage
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::prod_view(array.view())
+        })
+    }
+
+    fn float_prod_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::prod_dim(array, dim)
+        })
+    }
+
+    fn float_max(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        // Use view() for zero-copy on borrowed storage
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::max_view(array.view())
+        })
+    }
+
+    fn float_min(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        // Use view() for zero-copy on borrowed storage
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::min_view(array.view())
+        })
+    }
+
+    fn float_log1p(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array.mapv_into(|a: FloatElem| a.log1p_elem()).into_shared()
+        })
+    }
+
+    fn float_powf_scalar_impl(tensor: FloatTensor<Self>, value: Scalar) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| a.powf_elem(value.elem()))
+                .into_shared()
+        })
+    }
+
+    fn float_sqrt(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        #[cfg(feature = "simd")]
+        let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vssqrt) {
+            Ok(result) => return result,
+            Err(t) => t,
+        };
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array.mapv_into(|a: FloatElem| a.sqrt_elem()).into_shared()
+        })
+    }
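Note the shape of the SIMD dispatch used by `float_exp`, `float_log`, and `float_sqrt`: `try_vml_unary` returns the untouched tensor in its `Err` arm, so the scalar fallback can still consume it by value without a clone. The same ownership hand-off in miniature, as a stand-alone illustration independent of the backend types:

    // Hypothetical illustration of the try-fast-path pattern used above.
    fn try_fast(v: Vec<f32>) -> Result<Vec<f32>, Vec<f32>> {
        if v.len() % 16 == 0 {
            Ok(v.iter().map(|x| x.exp()).collect()) // stand-in for the SIMD path
        } else {
            Err(v) // hand the buffer back untouched
        }
    }

    fn exp_all(v: Vec<f32>) -> Vec<f32> {
        let v = match try_fast(v) {
            Ok(out) => return out,
            Err(v) => v, // fall through with ownership intact
        };
        v.into_iter().map(|x| x.exp()).collect() // scalar fallback
    }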
+    fn float_abs(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        #[cfg(feature = "simd")]
+        let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vsabs) {
+            Ok(result) => return result,
+            Err(t) => t,
+        };
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::abs(array)
+        })
+    }
+
+    fn float_cos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        #[cfg(feature = "simd")]
+        let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vscos) {
+            Ok(result) => return result,
+            Err(t) => t,
+        };
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).cos().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_cosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).cosh().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_sin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        #[cfg(feature = "simd")]
+        let tensor = match try_vml_unary(tensor, ndarray::hpc::vml::vssin) {
+            Ok(result) => return result,
+            Err(t) => t,
+        };
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).sin().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_sinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).sinh().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_tan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).tan().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_tanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).tanh().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_acos(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).acos().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_acosh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).acosh().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_asin(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).asin().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_asinh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).asinh().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_atan(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).atan().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_atanh(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).atanh().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_atan2(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
+            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.atan2(*b))
+        })
+    }
+
+    fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| round_ties_even_wrapper(a.to_f64()).elem())
+                .into_shared()
+        })
+    }
+
+    fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).floor().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).ceil().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_trunc(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| (a.to_f64()).trunc().elem())
+                .into_shared()
+        })
+    }
+
+    fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            array
+                .mapv_into(|a: FloatElem| erf(a.to_f64()).elem())
+                .into_shared()
+        })
+    }
+
+    fn float_cat(tensors: Vec<FloatTensor<Self>>, dim: usize) -> FloatTensor<Self> {
+        cat_with_dtype!(tensors, dim, [F64, F32])
+    }
+
+    fn float_clamp_min(tensor: FloatTensor<Self>, min: Scalar) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::clamp_min(array, min.elem())
+        })
+    }
+
+    fn float_clamp_max(tensor: FloatTensor<Self>, max: Scalar) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::clamp_max(array, max.elem())
+        })
+    }
+
+    fn float_clamp(tensor: FloatTensor<Self>, min: Scalar, max: Scalar) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::clamp(array, min.elem(), max.elem())
+        })
+    }
+
+    fn float_into_int(tensor: FloatTensor<Self>, out_dtype: IntDType) -> NdArrayTensor {
+        execute_with_int_out_dtype!(out_dtype, I, {
+            execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+                array.mapv(|a: FloatElem| a.elem::<I>()).into_shared()
+            })
+        })
+    }
+
+    fn float_powf(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!((lhs, rhs), FloatElem, |lhs, rhs| {
+            NdArrayMathOps::elementwise_op(lhs, rhs, |a: &FloatElem, b: &FloatElem| a.powf(*b))
+        })
+    }
+
+    fn float_permute(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayOps::permute(array, axes)
+        })
+    }
+
+    fn float_flip(tensor: FloatTensor<Self>, axes: &[usize]) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayOps::flip(array, axes)
+        })
+    }
+
+    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayMathOps::sign_op(array)
+        })
+    }
+
+    fn float_expand(tensor: FloatTensor<Self>, shape: Shape) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayOps::expand(array, shape)
+        })
+    }
+
+    fn float_cast(tensor: FloatTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            cast_to_dtype(array, dtype.into())
+        })
+    }
+
+    fn float_grid_sample_2d(
+        tensor: FloatTensor<Self>,
+        grid: FloatTensor<Self>,
+        options: GridSampleOptions,
+    ) -> FloatTensor<Self> {
+        execute_with_float_dtype!((tensor, grid), |tensor, grid| grid_sample_2d(
+            tensor, grid, options
+        ))
+    }
+
+    fn float_unfold(
+        tensor: FloatTensor<Self>,
+        dim: usize,
+        size: usize,
+        step: usize,
+    ) -> FloatTensor<Self> {
+        execute_with_float_dtype!(tensor, FloatElem, |array: SharedArray<FloatElem>| {
+            NdArrayOps::unfold(array, dim, size, step)
+        })
+    }
+}
diff --git a/crates/burn/src/ops/transaction.rs b/crates/burn/src/ops/transaction.rs
new file mode 100644
index 00000000..b308c0f0
--- /dev/null
+++ b/crates/burn/src/ops/transaction.rs
@@ -0,0 +1,13 @@
+use crate::{
+    FloatNdArrayElement, NdArray, NdArrayTensor, SharedArray,
+    element::{IntNdArrayElement, QuantElement},
+};
+use burn_backend::ops::TransactionOps;
+
+impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> TransactionOps<Self>
+    for NdArray<E, I, Q>
+where
+    NdArrayTensor: From<SharedArray<E>>,
+    NdArrayTensor: From<SharedArray<I>>,
+{
+}
diff --git a/crates/burn/src/parallel.rs b/crates/burn/src/parallel.rs
new file mode 100644
index 00000000..a6657619
--- /dev/null
+++ b/crates/burn/src/parallel.rs
@@ -0,0 +1,76 @@
+/// Macro for running a function in parallel.
+#[cfg(feature = "multi-threads")]
+#[macro_export(local_inner_macros)]
+macro_rules! run_par {
+    (
+        $func:expr
+    ) => {{
+        use rayon::prelude::*;
+
+        #[allow(clippy::redundant_closure_call)]
+        rayon::scope(|_| $func())
+    }};
+}
+
+/// Macro for running a function in parallel.
+#[cfg(not(feature = "multi-threads"))]
+#[macro_export(local_inner_macros)]
+macro_rules! run_par {
+    (
+        $func:expr
+    ) => {{ $func() }};
+}
+
+/// Macro for iterating in parallel.
+#[cfg(not(feature = "multi-threads"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_par {
+    (
+        $iter:expr
+    ) => {{ $iter }};
+}
+
+/// Macro for iterating in parallel.
+#[cfg(feature = "multi-threads")]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_par {
+    (
+        $iter:expr
+    ) => {{ $iter.into_par_iter() }};
+}
+
+/// Macro for iterating over a slice in parallel.
+#[cfg(feature = "multi-threads")]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_slice_par {
+    (
+        $slice:expr
+    ) => {{ $slice.into_par_iter() }};
+}
+
+/// Macro for iterating over a slice in parallel.
+#[cfg(not(feature = "multi-threads"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_slice_par {
+    (
+        $slice:expr
+    ) => {{ $slice.iter() }};
+}
+
+/// Macro for iterating over a range in parallel.
+#[cfg(feature = "multi-threads")]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_range_par {
+    (
+        $start:expr, $end:expr
+    ) => {{ ($start..$end).into_par_iter() }};
+}
+
+/// Macro for iterating over a range in parallel.
+#[cfg(not(feature = "multi-threads"))]
+#[macro_export(local_inner_macros)]
+macro_rules! iter_range_par {
+    (
+        $start:expr, $end:expr
+    ) => {{ ($start..$end) }};
+}
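Taken together, these macros let a single kernel body compile both with and without the `multi-threads` feature. A sketch (hypothetical function, using only the macros defined above):

    // With `multi-threads` this runs on the rayon thread pool; without it,
    // the identical source degrades to a plain sequential range loop.
    fn sum_of_squares(data: &[f32]) -> f32 {
        run_par!(|| {
            iter_range_par!(0, data.len())
                .map(|i| data[i] * data[i])
                .sum::<f32>()
        })
    }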
diff --git a/crates/burn/src/rand.rs b/crates/burn/src/rand.rs
new file mode 100644
index 00000000..94b9bcda
--- /dev/null
+++ b/crates/burn/src/rand.rs
@@ -0,0 +1,36 @@
+//! Random number generation utilities for burn-ndarray
+
+#[cfg(not(feature = "std"))]
+use rand::rngs::SmallRng;
+#[cfg(feature = "std")]
+use rand::rngs::StdRng;
+
+/// Type alias for the RNG used by burn-ndarray
+#[cfg(feature = "std")]
+pub type NdArrayRng = StdRng;
+#[cfg(not(feature = "std"))]
+pub type NdArrayRng = SmallRng;
+
+#[cfg(not(feature = "std"))]
+use rand::SeedableRng;
+
+/// Get a seeded random number generator
+///
+/// For std builds, uses OS entropy.
+/// For no_std builds, uses a compile-time random seed.
+#[cfg(feature = "std")]
+pub fn get_seeded_rng() -> NdArrayRng {
+    // Use the standard implementation from burn-std
+    burn_std::rand::get_seeded_rng()
+}
+
+/// Get a seeded random number generator
+///
+/// For std builds, uses OS entropy.
+/// For no_std builds, uses a compile-time random seed.
+#[cfg(not(feature = "std"))]
+pub fn get_seeded_rng() -> NdArrayRng {
+    // Use compile-time random seed for no_std
+    const SEED: u64 = const_random::const_random!(u64);
+    SmallRng::seed_from_u64(SEED)
+}
diff --git a/crates/burn/src/sharing.rs b/crates/burn/src/sharing.rs
new file mode 100644
index 00000000..75d51421
--- /dev/null
+++ b/crates/burn/src/sharing.rs
@@ -0,0 +1,19 @@
+use core::cell::UnsafeCell;
+
+/// Similar to `SyncUnsafeCell`; see the tracking [Rust issue](https://github.com/rust-lang/rust/issues/95439).
+pub(crate) struct UnsafeSharedRef<'a, T> {
+    cell: UnsafeCell<&'a mut T>,
+}
+
+unsafe impl<T> Sync for UnsafeSharedRef<'_, T> {}
+
+impl<'a, T> UnsafeSharedRef<'a, T> {
+    pub fn new(data: &'a mut T) -> Self {
+        Self {
+            cell: UnsafeCell::new(data),
+        }
+    }
+    pub unsafe fn get(&self) -> &'a mut T {
+        unsafe { core::ptr::read(self.cell.get()) }
+    }
+}
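`UnsafeSharedRef` lets a parallel kernel hand one mutable buffer to several tasks that write disjoint regions; the type itself enforces nothing, so soundness rests entirely on the caller. A minimal sketch of that contract (hypothetical helper, assuming the rayon-backed `multi-threads` feature):

    // CAUTION: sound only because the two closures write disjoint halves.
    fn fill_halves(out: &mut Vec<f32>) {
        let len = out.len();
        let shared = UnsafeSharedRef::new(out);
        rayon::join(
            || unsafe { shared.get() }[..len / 2].fill(1.0),
            || unsafe { shared.get() }[len / 2..].fill(2.0),
        );
    }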
diff --git a/crates/burn/src/storage.rs b/crates/burn/src/storage.rs
new file mode 100644
index 00000000..7eeca47f
--- /dev/null
+++ b/crates/burn/src/storage.rs
@@ -0,0 +1,506 @@
+//! Copy-on-write storage for zero-copy tensor loading.
+//!
+//! This module provides `NdArrayStorage`, which enables true zero-copy loading
+//! from burnpack files. When data is borrowed from external memory (like mmap'd files
+//! or static data), it remains zero-copy until a mutating operation is performed,
+//! at which point it's copied (copy-on-write semantics).
+//!
+//! This integrates with ndarray's existing COW patterns - operations that check
+//! `is_unique()` will see borrowed data as non-unique, triggering the allocation path.
+
+use burn_backend::Element;
+use burn_std::{Bytes, Shape};
+use core::mem;
+use ndarray::{ArcArray, ArrayView, IxDyn};
+
+/// Storage that supports both owned data and borrowed (zero-copy) data.
+///
+/// # Copy-on-Write Semantics
+///
+/// - **Borrowed**: Data from external source (burnpack, mmap, static).
+///   Reports `is_unique() == false` to trigger copy on mutation.
+/// - **Owned**: Standard `ArcArray` with built-in COW via Arc refcount.
+///
+/// # Example
+///
+/// ```ignore
+/// // Zero-copy load
+/// let storage = NdArrayStorage::from_borrowed(bytes, shape);
+/// storage.is_unique(); // false - will copy on mutation
+///
+/// // Read operations use view() - zero-copy
+/// let view = storage.view();
+///
+/// // Mutation converts to owned
+/// let owned = storage.into_owned(); // Copies here
+/// ```
+#[derive(Debug)]
+pub enum NdArrayStorage<E> {
+    /// Borrowed from external source (e.g., burnpack zero-copy load).
+    /// Keeps `Bytes` alive to ensure the referenced memory is valid.
+    Borrowed {
+        /// Source bytes - keeps external memory alive via reference counting
+        bytes: Bytes,
+        /// Shape of the tensor
+        shape: Shape,
+    },
+
+    /// Standard owned storage with ArcArray COW semantics.
+    Owned(ArcArray<E, IxDyn>),
+}
+
+impl<E> Clone for NdArrayStorage<E> {
+    fn clone(&self) -> Self {
+        match self {
+            // For borrowed data, clone the Bytes (cheap Arc clone) and shape
+            Self::Borrowed { bytes, shape } => Self::Borrowed {
+                bytes: bytes.clone(),
+                shape: shape.clone(),
+            },
+            // For owned data, clone the ArcArray (cheap Arc clone)
+            Self::Owned(arr) => Self::Owned(arr.clone()),
+        }
+    }
+}
+
+impl<E: Element> NdArrayStorage<E> {
+    /// Create borrowed storage from external bytes.
+    ///
+    /// Returns the bytes and shape back on failure (misaligned or too small),
+    /// enabling zero-copy even for native allocations by avoiding defensive cloning.
+    ///
+    /// # Requirements
+    ///
+    /// The caller must ensure that:
+    /// - The `Bytes` contain valid data for the element type `E`
+    /// - The data is contiguous in row-major (C) order matching the provided shape
+    ///
+    /// These requirements are upheld when loading from `TensorData` (burnpack, etc.)
+    /// which always stores data contiguously in row-major order.
+    pub fn from_borrowed(bytes: Bytes, shape: impl Into<Shape>) -> Result<Self, (Bytes, Shape)> {
+        let shape = shape.into();
+        // Validate alignment
+        let ptr = bytes.as_ptr();
+        if !(ptr as usize).is_multiple_of(mem::align_of::<E>()) {
+            return Err((bytes, shape));
+        }
+
+        // Validate size (using checked arithmetic to prevent overflow)
+        let num_elements = match shape
+            .iter()
+            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
+        {
+            Some(n) => n,
+            None => return Err((bytes, shape)),
+        };
+        let expected_size = match num_elements.checked_mul(mem::size_of::<E>()) {
+            Some(s) => s,
+            None => return Err((bytes, shape)),
+        };
+        if bytes.len() < expected_size {
+            return Err((bytes, shape));
+        }
+
+        Ok(Self::Borrowed { bytes, shape })
+    }
+
+    /// Create owned storage from an ArcArray.
+    #[inline]
+    pub fn from_owned(array: ArcArray<E, IxDyn>) -> Self {
+        Self::Owned(array)
+    }
+
+    /// Returns whether this storage is uniquely owned and can be mutated in-place.
+    ///
+    /// - **Borrowed**: Always returns `false` to trigger copy-on-write.
+    /// - **Owned**: Delegates to `ArcArray::is_unique()`.
+    ///
+    /// This integrates with existing SIMD code patterns like:
+    /// ```ignore
+    /// if tensor.is_unique() {
+    ///     // mutate in place
+    /// } else {
+    ///     // allocate new
+    /// }
+    /// ```
+    #[inline]
+    pub fn is_unique(&self) -> bool {
+        match self {
+            Self::Borrowed { .. } => false, // Force copy path
+            Self::Owned(arr) => arr.is_unique(),
+        }
+    }
+
+    /// Get a read-only view of the data.
+    ///
+    /// This is zero-copy for both borrowed and owned variants.
+    #[inline]
+    pub fn view(&self) -> ArrayView<'_, E, IxDyn> {
+        match self {
+            Self::Borrowed { bytes, shape } => {
+                let ptr = bytes.as_ptr() as *const E;
+                let dim = IxDyn(shape);
+                // SAFETY:
+                // - `bytes` is kept alive for the lifetime of `self`
+                // - Alignment was validated in `from_borrowed`
+                // - Size was validated in `from_borrowed`
+                unsafe { ArrayView::from_shape_ptr(dim, ptr) }
+            }
+            Self::Owned(arr) => arr.view(),
+        }
+    }
+
+    /// Convert to owned ArcArray.
+    ///
+    /// - **Borrowed**: Copies the data into a new ArcArray.
+    /// - **Owned + unique**: Returns the array without copying.
+    /// - **Owned + shared**: Clones the data.
+    pub fn into_owned(self) -> ArcArray<E, IxDyn> {
+        match self {
+            Self::Borrowed { bytes, shape } => {
+                let ptr = bytes.as_ptr() as *const E;
+                let dim = IxDyn(&shape);
+                // SAFETY: Same as view() - bytes is valid for this scope
+                let view = unsafe { ArrayView::from_shape_ptr(dim, ptr) };
+                view.to_owned().into_shared()
+            }
+            Self::Owned(arr) => arr,
+        }
+    }
+
+    /// Convert to shared ArcArray, suitable for returning from operations.
+    ///
+    /// This is equivalent to `into_owned()` but named for clarity.
+    #[inline]
+    pub fn into_shared(self) -> ArcArray<E, IxDyn> {
+        self.into_owned()
+    }
+
+    /// Get the shape of the tensor.
+    pub fn shape(&self) -> &[usize] {
+        match self {
+            Self::Borrowed { shape, .. } => shape,
+            Self::Owned(arr) => arr.shape(),
+        }
+    }
+
+    /// Get the number of dimensions.
+    #[inline]
+    pub fn ndim(&self) -> usize {
+        self.shape().len()
+    }
+
+    /// Get the total number of elements.
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.shape().iter().product()
+    }
+
+    /// Check if the tensor is empty.
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Returns `true` if this is borrowed (zero-copy) storage.
+    #[inline]
+    pub fn is_borrowed(&self) -> bool {
+        matches!(self, Self::Borrowed { .. })
+    }
+
+    /// Returns `true` if this is owned storage.
+    #[inline]
+    pub fn is_owned(&self) -> bool {
+        matches!(self, Self::Owned(_))
+    }
+
+    /// Ensure owned and return mutable reference to the ArcArray.
+    ///
+    /// Converts borrowed to owned if necessary.
+    pub fn ensure_owned(&mut self) -> &mut ArcArray<E, IxDyn> {
+        if let Self::Borrowed { bytes, shape } = self {
+            let ptr = bytes.as_ptr() as *const E;
+            let dim = IxDyn(shape);
+            // SAFETY: Same as view()
+            let view = unsafe { ArrayView::from_shape_ptr(dim, ptr) };
+            *self = Self::Owned(view.to_owned().into_shared());
+        }
+        match self {
+            Self::Owned(arr) => arr,
+            Self::Borrowed { .. } => unreachable!(),
+        }
+    }
+}
+
+/// Convert from ArcArray to NdArrayStorage.
+impl<E> From<ArcArray<E, IxDyn>> for NdArrayStorage<E> {
+    fn from(array: ArcArray<E, IxDyn>) -> Self {
+        Self::Owned(array)
+    }
+}
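Putting the pieces together, the intended copy-on-write lifecycle is: borrow, read through views, and pay for exactly one copy on first mutation. A sketch using the same `Bytes::from_elems` input as the tests below:

    let bytes = Bytes::from_elems(vec![1.0f32, 2.0, 3.0, 4.0]);
    let mut storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2])
        .expect("aligned, correctly sized input");

    // Reads stay zero-copy...
    let total: f32 = storage.view().sum();

    // ...and the first mutation converts to owned storage exactly once.
    let array = storage.ensure_owned();
    array[[0, 0]] = total;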
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use alloc::{vec, vec::Vec};
+    use burn_std::Bytes;
+
+    #[test]
+    fn test_borrowed_is_not_unique() {
+        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
+        let bytes = Bytes::from_elems(data);
+        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect("should create");
+
+        assert!(!storage.is_unique());
+        assert!(storage.is_borrowed());
+    }
+
+    #[test]
+    fn test_owned_unique_when_single_ref() {
+        let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared();
+        let storage = NdArrayStorage::from_owned(array);
+
+        assert!(storage.is_unique());
+        assert!(storage.is_owned());
+    }
+
+    #[test]
+    fn test_owned_not_unique_when_cloned() {
+        let array = ndarray::ArrayD::from_elem(IxDyn(&[2, 2]), 1.0f32).into_shared();
+        let storage = NdArrayStorage::from_owned(array);
+        let _clone = storage.clone();
+
+        assert!(!storage.is_unique());
+    }
+
+    #[test]
+    fn test_view_zero_copy() {
+        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
+        let bytes = Bytes::from_elems(data);
+        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect("should create");
+
+        let view = storage.view();
+        assert_eq!(view[[0, 0]], 1.0);
+        assert_eq!(view[[1, 1]], 4.0);
+    }
+
+    #[test]
+    fn test_into_owned_copies_borrowed() {
+        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
+        let bytes = Bytes::from_elems(data);
+        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect("should create");
+
+        let owned = storage.into_owned();
+        assert_eq!(owned[[0, 0]], 1.0);
+        assert_eq!(owned[[1, 1]], 4.0);
+    }
+
+    #[test]
+    fn test_from_borrowed_validates_alignment() {
+        use burn_std::AllocationProperty;
+
+        // Test 1: Properly aligned data should succeed
+        let aligned_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
+        let aligned_bytes = Bytes::from_elems(aligned_data);
+
+        // Verify test setup - should be 4-byte aligned for f32
+        assert_eq!(
+            (aligned_bytes.as_ptr() as usize) % core::mem::align_of::<f32>(),
+            0,
+            "Test setup: f32 data should be properly aligned"
+        );
+
+        let result = NdArrayStorage::<f32>::from_borrowed(aligned_bytes, [2, 2]);
+        assert!(
+            result.is_ok(),
+            "from_borrowed should succeed for properly aligned data"
+        );
+
+        // Test 2: Misaligned data should fail
+        // Create a buffer large enough to find a misaligned offset
+        // (static data placement varies by platform, so we find an offset dynamically)
+        let buffer: &[u8] = &[0u8; 32];
+        let shared = bytes::Bytes::from_static(buffer);
+        let base = shared.as_ptr() as usize;
+        let align = core::mem::align_of::<f32>();
+
+        // Find an offset in 1..align that produces misalignment (at least one must exist)
+        let misalign_offset = (1..align)
+            .find(|&off| !(base + off).is_multiple_of(align))
+            .expect("Should find a misaligned offset");
+
+        let sliced = shared.slice(misalign_offset..(misalign_offset + 16));
+        let misaligned_bytes = Bytes::from_shared(sliced, AllocationProperty::Other);
+
+        // Verify test setup - should NOT be 4-byte aligned
+        assert_ne!(
+            (misaligned_bytes.as_ptr() as usize) % align,
+            0,
+            "Test setup: sliced data should be misaligned for f32"
+        );
+
+        let result = NdArrayStorage::<f32>::from_borrowed(misaligned_bytes, [4]);
+        assert!(
+            result.is_err(),
+            "from_borrowed should return Err for misaligned data"
+        );
+    }
+
+    #[test]
+    fn test_insufficient_size_returns_err() {
+        // Create bytes that are too small for the requested shape
+        let data: Vec<f32> = vec![1.0, 2.0]; // 8 bytes
+        let bytes = Bytes::from_elems(data);
+
+        // Try to create storage for 4 elements (needs 16 bytes)
+        let result = NdArrayStorage::<f32>::from_borrowed(bytes, [4]);
+        assert!(
+            result.is_err(),
+            "from_borrowed should return Err when bytes are too small"
+        );
+    }
+
+    // ==========================================================================
+    // Zero-copy hardening tests
+    // These tests verify the zero-copy guarantee is maintained. If any of these
+    // fail, it indicates a regression in zero-copy functionality.
+    // ==========================================================================
+
+    #[test]
+    fn test_zero_copy_native_allocation() {
+        // CRITICAL: Verify that native allocations (Bytes::from_elems) are zero-copy
+        // on initial load. The view() must return a pointer to the SAME memory.
+        //
+        // Note: Native allocations copy on clone (this is expected), but the initial
+        // load is still zero-copy, avoiding an extra copy in the common case where
+        // the tensor is used without cloning.
+        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
+        let bytes = Bytes::from_elems(data);
+        let original_ptr = bytes.as_ptr();
+
+        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect("should create");
+
+        // Initial load must be zero-copy
+        let view = storage.view();
+        let view_ptr = view.as_ptr() as *const u8;
+
+        assert_eq!(
+            original_ptr, view_ptr,
+            "ZERO-COPY REGRESSION: native allocation view() must return pointer to original bytes"
+        );
+
+        // Verify data integrity
+        assert_eq!(view[[0, 0]], 1.0);
+        assert_eq!(view[[0, 1]], 2.0);
+        assert_eq!(view[[1, 0]], 3.0);
+        assert_eq!(view[[1, 1]], 4.0);
+    }
+
+    #[test]
+    fn test_zero_copy_shared_bytes_pointer_identity() {
+        // CRITICAL: Test with SharedBytesAllocationController for true zero-copy.
+        // This simulates the actual burnpack/mmap loading path.
+        use burn_std::AllocationProperty;
+
+        // Create static-like data using bytes::Bytes
+        let data: &[u8] = &[
+            0, 0, 128, 63, // 1.0f32 in little-endian
+            0, 0, 0, 64, // 2.0f32
+            0, 0, 64, 64, // 3.0f32
+            0, 0, 128, 64, // 4.0f32
+        ];
+        let shared = bytes::Bytes::from_static(data);
+        let original_ptr = shared.as_ptr();
+
+        // Create Bytes with SharedBytesAllocationController
+        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
+
+        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect("should create");
+
+        // Verify pointer identity
+        let view_ptr = storage.view().as_ptr() as *const u8;
+        assert_eq!(
+            original_ptr, view_ptr,
+            "ZERO-COPY REGRESSION: SharedBytes view must point to original static data"
+        );
+
+        // Clone should also share the same memory
+        let cloned = storage.clone();
+        let cloned_ptr = cloned.view().as_ptr() as *const u8;
+        assert_eq!(
+            original_ptr, cloned_ptr,
+            "ZERO-COPY REGRESSION: SharedBytes clone must share memory"
+        );
+    }
+
+    #[test]
+    fn test_clone_borrowed_stays_borrowed() {
+        // Verify that cloning borrowed storage produces another borrowed storage.
+        // Note: The underlying Bytes may or may not share memory depending on
+        // the allocation controller (native allocations copy, file-backed may share).
+        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
+        let bytes = Bytes::from_elems(data);
+
+        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect("should create");
+        let cloned = storage.clone();
+
+        // Both should still be borrowed (the storage type is preserved)
+        assert!(
+            storage.is_borrowed(),
+            "ZERO-COPY REGRESSION: original should remain borrowed after clone"
+        );
+        assert!(
+            cloned.is_borrowed(),
+            "ZERO-COPY REGRESSION: clone should be borrowed type"
+        );
+
+        // Both should report not unique (important for COW behavior)
+        assert!(
+            !storage.is_unique(),
+            "ZERO-COPY REGRESSION: original should not be unique after clone"
+        );
+        assert!(
+            !cloned.is_unique(),
+            "ZERO-COPY REGRESSION: clone should not be unique"
+        );
+
+        // Data should be identical
+        assert_eq!(storage.view(), cloned.view(), "Clone should have same data");
+    }
+
+    #[test]
+    fn test_zero_copy_triggers_copy_on_mutation() {
+        // Verify that into_owned() on borrowed data creates a NEW allocation
+        // (this is the "copy" in copy-on-write)
+        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
+        let bytes = Bytes::from_elems(data);
+        let original_ptr = bytes.as_ptr();
+
+        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect("should create");
+
+        assert!(storage.is_borrowed(), "should start as borrowed");
+
+        let owned = storage.into_owned();
+        let owned_ptr = owned.as_ptr() as *const u8;
+
+        assert_ne!(
+            original_ptr, owned_ptr,
+            "into_owned() on borrowed data MUST allocate new memory (copy-on-write)"
+        );
+    }
+
+    #[test]
+    fn test_borrowed_reports_not_unique() {
+        // CRITICAL: Borrowed storage must report is_unique() == false
+        // This is what triggers copy-on-write in mutation operations
+        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
+        let bytes = Bytes::from_elems(data);
+        let storage = NdArrayStorage::<f32>::from_borrowed(bytes, [2, 2]).expect("should create");
+
+        assert!(
+            !storage.is_unique(),
+            "ZERO-COPY REGRESSION: borrowed storage MUST report is_unique() == false \
+             to trigger copy-on-write. If this is true, mutations will corrupt shared data!"
+        );
+    }
+}
diff --git a/crates/burn/src/tensor.rs b/crates/burn/src/tensor.rs
new file mode 100644
index 00000000..97699a1f
--- /dev/null
+++ b/crates/burn/src/tensor.rs
@@ -0,0 +1,955 @@
+use burn_backend::{
+    AllocationProperty, DType, Element, QTensorPrimitive, Shape, TensorData, TensorMetadata,
+    quantization::{QParams, QuantLevel, QuantMode, QuantScheme, QuantValue},
+};
+use burn_std::BoolStore;
+
+use crate::NdArrayStorage;
+use crate::ops::quantization::{QuantizationStrategy, SymmetricQuantization};
+use alloc::vec::Vec;
+use ndarray::{ArcArray, ArrayD, IxDyn};
+
+/// Concrete storage type for ndarray (owned with COW semantics via Arc)
+pub type SharedArray<E> = ArcArray<E, IxDyn>;
+
+/// Tensor primitive used by the [ndarray backend](crate::NdArray).
+///
+/// Supports both owned and borrowed (zero-copy) data via `NdArrayStorage`.
+/// When data is borrowed from external sources (like burnpack files),
+/// it remains zero-copy until a mutating operation is performed.
+#[derive(Debug, Clone)]
+#[allow(missing_docs)]
+pub enum NdArrayTensor {
+    F64(NdArrayStorage<f64>),
+    F32(NdArrayStorage<f32>),
+    I64(NdArrayStorage<i64>),
+    I32(NdArrayStorage<i32>),
+    I16(NdArrayStorage<i16>),
+    I8(NdArrayStorage<i8>),
+    U64(NdArrayStorage<u64>),
+    U32(NdArrayStorage<u32>),
+    U16(NdArrayStorage<u16>),
+    U8(NdArrayStorage<u8>),
+    Bool(NdArrayStorage<bool>),
+}
+
+impl NdArrayTensor {
+    /// Extract bool array, converting to owned if necessary.
+    pub(crate) fn bool(self) -> SharedArray<bool> {
+        match self {
+            NdArrayTensor::Bool(storage) => storage.into_shared(),
+            _ => unimplemented!("Expected bool tensor, got {:?}", self.dtype()),
+        }
+    }
+
+    /// Returns true if this tensor uses borrowed (zero-copy) storage.
+    #[inline]
+    pub fn is_borrowed(&self) -> bool {
+        macro_rules! check {
+            ($($variant:ident),*) => {
+                match self {
+                    $(NdArrayTensor::$variant(s) => s.is_borrowed(),)*
+                }
+            };
+        }
+        check!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
+    }
+}
+
+pub(crate) fn cast_to_dtype<E1: Element>(array: SharedArray<E1>, dtype: DType) -> NdArrayTensor
+where
+    NdArrayTensor: From<SharedArray<E1>>,
+{
+    fn cast<E1: Element, E2: Element>(array: SharedArray<E1>) -> SharedArray<E2> {
+        array.mapv(|a| a.elem()).into_shared()
+    }
+
+    if E1::dtype() == dtype {
+        return array.into();
+    }
+
+    match dtype {
+        DType::F64 => cast::<_, f64>(array).into(),
+        DType::F32 => cast::<_, f32>(array).into(),
+        DType::Flex32 => cast::<_, f32>(array).into(),
+        DType::I64 => cast::<_, i64>(array).into(),
+        DType::I32 => cast::<_, i32>(array).into(),
+        DType::I16 => cast::<_, i16>(array).into(),
+        DType::I8 => cast::<_, i8>(array).into(),
+        DType::U64 => cast::<_, u64>(array).into(),
+        DType::U32 => cast::<_, u32>(array).into(),
+        DType::U16 => cast::<_, u16>(array).into(),
+        DType::U8 => cast::<_, u8>(array).into(),
+        DType::Bool(BoolStore::Native) => cast::<_, bool>(array).into(),
+        dtype => panic!("Unsupported dtype: {dtype:?}"),
+    }
+}
+
+macro_rules! impl_from {
+    ($($ty: ty => $dtype: ident),*) => {
+        // From SharedArray (owned) -> NdArrayTensor
+        $(impl From<SharedArray<$ty>> for NdArrayTensor {
+            fn from(value: SharedArray<$ty>) -> NdArrayTensor {
+                NdArrayTensor::$dtype(NdArrayStorage::from_owned(value))
+            }
+        })*
+
+        // From NdArrayStorage -> NdArrayTensor
+        $(impl From<NdArrayStorage<$ty>> for NdArrayTensor {
+            fn from(value: NdArrayStorage<$ty>) -> NdArrayTensor {
+                NdArrayTensor::$dtype(value)
+            }
+        })*
+    };
+}
+
+impl_from!(
+    f64 => F64, f32 => F32,
+    i64 => I64, i32 => I32, i16 => I16, i8 => I8,
+    u64 => U64, u32 => U32, u16 => U16, u8 => U8,
+    bool => Bool
+);
+
+/// Macro to execute an operation on a given element type.
+///
+/// Extracts the storage from NdArrayTensor, converts to SharedArray, and passes to operation.
+///
+/// # Panics
+/// Since there is no automatic type cast at this time, binary operations for different
+/// floating point precision data types will panic with a data type mismatch.
+#[macro_export]
+macro_rules! execute_with_dtype {
+    (($lhs:expr, $rhs:expr), $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
+        let lhs_dtype = burn_backend::TensorMetadata::dtype(&$lhs);
+        let rhs_dtype = burn_backend::TensorMetadata::dtype(&$rhs);
+        match ($lhs, $rhs) {
+            $(
+                ($crate::NdArrayTensor::$dtype(lhs), $crate::NdArrayTensor::$dtype(rhs)) => {
+                    #[allow(unused)]
+                    type $element = $ty;
+                    // Convert storage to SharedArray for compatibility with existing operations
+                    $op(lhs.into_shared(), rhs.into_shared()).into()
+                }
+            )*
+            _ => panic!(
+                "Data type mismatch (lhs: {:?}, rhs: {:?})",
+                lhs_dtype, rhs_dtype
+            ),
+        }
+    }};
+    // Binary op: type automatically inferred by the compiler
+    (($lhs:expr, $rhs:expr), $op:expr) => {{
+        $crate::execute_with_dtype!(($lhs, $rhs), E, $op)
+    }};
+
+    // Binary op: generic type cannot be inferred for an operation
+    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
+        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
+            F64 => f64, F32 => f32,
+            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
+            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
+            Bool => bool
+        ])
+    }};
+
+    ($tensor:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
+        match $tensor {
+            $(
+                $crate::NdArrayTensor::$dtype(storage) => {
+                    #[allow(unused)]
+                    type $element = $ty;
+                    // Convert to SharedArray for compatibility with most operations
+                    $op(storage.into_shared()).into()
+                }
+            )*
+            #[allow(unreachable_patterns)]
+            other => unimplemented!("unsupported dtype: {:?}", other.dtype())
+        }
+    }};
+    // Unary op: type automatically inferred by the compiler
+    ($tensor:expr, $op:expr) => {{
+        $crate::execute_with_dtype!($tensor, E, $op)
+    }};
+
+    // Unary op: generic type cannot be inferred for an operation
+    ($tensor:expr, $element:ident, $op:expr) => {{
+        $crate::execute_with_dtype!($tensor, $element, $op, [
+            F64 => f64, F32 => f32,
+            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
+            U64 => u64, U32 => u32, U16 => u16, U8 => u8,
+            Bool => bool
+        ])
+    }};
+}
+
+/// Macro to execute an operation on a given element type.
+/// Only handles float types.
+///
+/// # Panics
+/// Since there is no automatic type cast at this time, binary operations for different
+/// floating point precision data types will panic with a data type mismatch.
+#[macro_export]
+macro_rules! execute_with_float_dtype {
+    // Binary op: type automatically inferred by the compiler
+    (($lhs:expr, $rhs:expr), $op:expr) => {{
+        $crate::execute_with_float_dtype!(($lhs, $rhs), E, $op)
+    }};
+
+    // Binary op: generic type cannot be inferred for an operation
+    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
+        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
+            F64 => f64, F32 => f32
+        ])
+    }};
+
+    // Unary op: type automatically inferred by the compiler
+    ($tensor:expr, $op:expr) => {{
+        $crate::execute_with_float_dtype!($tensor, E, $op)
+    }};
+
+    // Unary op: generic type cannot be inferred for an operation
+    ($tensor:expr, $element:ident, $op:expr) => {{
+        $crate::execute_with_dtype!($tensor, $element, $op, [
+            F64 => f64, F32 => f32
+        ])
+    }};
+}
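Concretely, a unary call such as `execute_with_float_dtype!(tensor, E, op)` expands to a dtype match of this shape (hand-expanded sketch of the macro above):

    match tensor {
        NdArrayTensor::F64(storage) => {
            type E = f64;
            op(storage.into_shared()).into()
        }
        NdArrayTensor::F32(storage) => {
            type E = f32;
            op(storage.into_shared()).into()
        }
        other => unimplemented!("unsupported dtype: {:?}", other.dtype()),
    }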
+/// Macro to execute an operation on a given element type.
+/// Only handles int types.
+///
+/// # Panics
+/// Since there is no automatic type cast at this time, binary operations on different
+/// integer data types will panic with a data type mismatch.
+#[macro_export]
+macro_rules! execute_with_int_dtype {
+    // Binary op: type automatically inferred by the compiler
+    (($lhs:expr, $rhs:expr), $op:expr) => {{
+        $crate::execute_with_int_dtype!(($lhs, $rhs), E, $op)
+    }};
+
+    // Binary op: generic type cannot be inferred for an operation
+    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
+        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
+            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
+            U64 => u64, U32 => u32, U16 => u16, U8 => u8
+        ])
+    }};
+
+    // Unary op: type automatically inferred by the compiler
+    ($tensor:expr, $op:expr) => {{
+        $crate::execute_with_int_dtype!($tensor, E, $op)
+    }};
+
+    // Unary op: generic type cannot be inferred for an operation
+    ($tensor:expr, $element:ident, $op:expr) => {{
+        $crate::execute_with_dtype!($tensor, $element, $op, [
+            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
+            U64 => u64, U32 => u32, U16 => u16, U8 => u8
+        ])
+    }};
+}
+
+/// Macro to execute an operation on a given element type.
+/// Only handles numeric types.
+///
+/// # Panics
+/// Since there is no automatic type cast at this time, binary operations on different
+/// numeric data types will panic with a data type mismatch.
+#[macro_export]
+macro_rules! execute_with_numeric_dtype {
+    // Binary op: type automatically inferred by the compiler
+    (($lhs:expr, $rhs:expr), $op:expr) => {{
+        $crate::execute_with_numeric_dtype!(($lhs, $rhs), E, $op)
+    }};
+
+    // Binary op: generic type cannot be inferred for an operation
+    (($lhs:expr, $rhs:expr), $element:ident, $op:expr) => {{
+        $crate::execute_with_dtype!(($lhs, $rhs), $element, $op, [
+            F64 => f64, F32 => f32,
+            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
+            U64 => u64, U32 => u32, U16 => u16, U8 => u8
+        ])
+    }};
+
+    // Unary op: type automatically inferred by the compiler
+    ($tensor:expr, $op:expr) => {{
+        $crate::execute_with_numeric_dtype!($tensor, E, $op)
+    }};
+
+    // Unary op: generic type cannot be inferred for an operation
+    ($tensor:expr, $element:ident, $op:expr) => {{
+        $crate::execute_with_dtype!($tensor, $element, $op, [
+            F64 => f64, F32 => f32,
+            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
+            U64 => u64, U32 => u32, U16 => u16, U8 => u8
+        ])
+    }};
+}
+
+/// Macro to execute a cat operation on a given set of element types.
+///
+/// Uses zero-copy views from storage for concatenation.
+///
+/// # Panics
+/// Since there is no automatic type cast at this time, concatenating tensors with
+/// different data types will panic with a data type mismatch.
+#[macro_export]
+macro_rules! cat_with_dtype {
+    ($tensors: expr, $dim: expr, [$($dtype: ident),*]) => {
+        match &$tensors[0] {
+            $(NdArrayTensor::$dtype(_) => {
+                let tensors = $tensors
+                    .iter()
+                    .map(|t| {
+                        if let NdArrayTensor::$dtype(storage) = t {
+                            // Use storage.view() for zero-copy access
+                            storage.view()
+                        } else {
+                            panic!("Concatenate data type mismatch (expected {:?}, got {:?})", $tensors[0].dtype(), t.dtype())
+                        }
+                    })
+                    .collect::<Vec<_>>();
+                NdArrayOps::concatenate(&tensors, $dim).into()
+            })*
+            _ => panic!("Unsupported dtype: {:?}", $tensors[0].dtype())
+        }
+    };
+}
+
+/// Macro to execute an operation that returns a given element type.
+#[macro_export]
+macro_rules! execute_with_float_out_dtype {
+    ($out_dtype:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
+        match $out_dtype {
+            $(
+                burn_std::FloatDType::$dtype => {
+                    #[allow(unused)]
+                    type $element = $ty;
+                    $op
+                }
+            )*
+            #[allow(unreachable_patterns)]
+            other => unimplemented!("unsupported dtype: {other:?}")
+        }
+    }};
+    // Unary op: type automatically inferred by the compiler
+    ($out_dtype:expr, $op:expr) => {{
+        $crate::execute_with_float_out_dtype!($out_dtype, E, $op)
+    }};
+
+    // Unary op: generic type cannot be inferred for an operation
+    ($out_dtype:expr, $element:ident, $op:expr) => {{
+        $crate::execute_with_float_out_dtype!($out_dtype, $element, $op, [
+            F64 => f64, F32 => f32
+        ])
+    }};
+}
+
+/// Macro to execute an operation that returns a given element type.
+#[macro_export]
+macro_rules! execute_with_int_out_dtype {
+    ($out_dtype:expr, $element:ident, $op:expr, [$($dtype: ident => $ty: ty),*]) => {{
+        match $out_dtype {
+            $(
+                burn_std::IntDType::$dtype => {
+                    #[allow(unused)]
+                    type $element = $ty;
+                    $op
+                }
+            )*
+            #[allow(unreachable_patterns)]
+            other => unimplemented!("unsupported dtype: {other:?}")
+        }
+    }};
+    // Unary op: type automatically inferred by the compiler
+    ($out_dtype:expr, $op:expr) => {{
+        $crate::execute_with_int_out_dtype!($out_dtype, E, $op)
+    }};
+
+    // Unary op: generic type cannot be inferred for an operation
+    ($out_dtype:expr, $element:ident, $op:expr) => {{
+        $crate::execute_with_int_out_dtype!($out_dtype, $element, $op, [
+            I64 => i64, I32 => i32, I16 => i16, I8 => i8,
+            U64 => u64, U32 => u32, U16 => u16, U8 => u8
+        ])
+    }};
+}
+
+impl TensorMetadata for NdArrayTensor {
+    fn dtype(&self) -> DType {
+        match self {
+            NdArrayTensor::F64(_) => DType::F64,
+            NdArrayTensor::F32(_) => DType::F32,
+            NdArrayTensor::I64(_) => DType::I64,
+            NdArrayTensor::I32(_) => DType::I32,
+            NdArrayTensor::I16(_) => DType::I16,
+            NdArrayTensor::I8(_) => DType::I8,
+            NdArrayTensor::U64(_) => DType::U64,
+            NdArrayTensor::U32(_) => DType::U32,
+            NdArrayTensor::U16(_) => DType::U16,
+            NdArrayTensor::U8(_) => DType::U8,
+            NdArrayTensor::Bool(_) => DType::Bool(BoolStore::Native),
+        }
+    }
+
+    fn shape(&self) -> Shape {
+        // Use storage's shape method (works for both borrowed and owned)
+        macro_rules! get_shape {
+            ($($variant:ident),*) => {
+                match self {
+                    $(NdArrayTensor::$variant(storage) => Shape::from(storage.shape().to_vec()),)*
+                }
+            };
+        }
+        get_shape!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
+    }
+
+    fn rank(&self) -> usize {
+        self.shape().num_dims()
+    }
+}
+
+pub(crate) trait ShapeOps {
+    fn num_dims(self) -> usize;
+    fn num_elements(self) -> usize;
+    fn dims<const N: usize>(self) -> [usize; N];
+    fn into_shape(self) -> Shape;
+}
+
+impl ShapeOps for &[usize] {
+    fn num_dims(self) -> usize {
+        self.len()
+    }
+
+    fn num_elements(self) -> usize {
+        self.iter().product()
+    }
+
+    fn dims<const N: usize>(self) -> [usize; N] {
+        self.try_into().unwrap()
+    }
+
+    fn into_shape(self) -> Shape {
+        Shape::from(self)
+    }
+}
+
+mod utils {
+    use burn_std::tensor::is_contiguous;
+
+    use super::*;
+
+    impl NdArrayTensor {
+        pub(crate) fn into_data(self) -> TensorData {
+            let shape = self.shape();
+            let contiguous = self.is_contiguous();
+
+            fn inner<E: Element>(
+                shape: Shape,
+                is_contiguous: bool,
+                array: ArcArray<E, IxDyn>,
+            ) -> TensorData {
+                let vec = if is_contiguous {
+                    match array.try_into_owned_nocopy() {
+                        Ok(owned) => {
+                            let (mut vec, offset) = owned.into_raw_vec_and_offset();
+                            if let Some(offset) = offset {
+                                vec.drain(..offset);
+                            }
+                            if vec.len() > shape.num_elements() {
+                                vec.drain(shape.num_elements()..vec.len());
+                            }
+                            vec
+                        }
+                        Err(array) => array.into_iter().collect(),
+                    }
+                } else {
+                    array.into_iter().collect()
+                };
+
+                TensorData::new(vec, shape)
+            }
+
+            // Convert storage to owned array before extracting data
+            execute_with_dtype!(self, |arr| inner(shape, contiguous, arr))
+        }
+
+        pub(crate) fn is_contiguous(&self) -> bool {
+            // For borrowed data, we assume it's contiguous (it came from TensorData which is contiguous)
+            // For owned data, we check the strides
+            macro_rules! check_contiguous {
+                ($($variant:ident),*) => {
+                    match self {
+                        $(NdArrayTensor::$variant(storage) => {
+                            match storage {
+                                NdArrayStorage::Borrowed { .. } => {
+                                    // Borrowed storage requires contiguous row-major data
+                                    // (see NdArrayStorage::from_borrowed documentation)
+                                    true
+                                }
+                                NdArrayStorage::Owned(array) => {
+                                    let shape = array.shape();
+                                    let mut strides = Vec::with_capacity(array.strides().len());
+                                    for &stride in array.strides() {
+                                        if stride <= 0 {
+                                            return false;
+                                        }
+                                        strides.push(stride as usize);
+                                    }
+                                    is_contiguous(shape, &strides)
+                                }
+                            }
+                        })*
+                    }
+                };
+            }
+            check_contiguous!(F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
+        }
+    }
+}
+
+/// Converts a slice of usize to a typed dimension.
+#[macro_export(local_inner_macros)]
+macro_rules! to_typed_dims {
+    (
+        $n:expr,
+        $dims:expr,
+        justdim
+    ) => {{
+        let mut dims = [0; $n];
+        for i in 0..$n {
+            dims[i] = $dims[i];
+        }
+        let dim: Dim<[usize; $n]> = Dim(dims);
+        dim
+    }};
+}
+
+/// Reshapes an array into a tensor.
+#[macro_export(local_inner_macros)]
+macro_rules! reshape {
+    (
+        ty $ty:ty,
+        n $n:expr,
+        shape $shape:expr,
+        array $array:expr
+    ) => {{
+        let dim = $crate::to_typed_dims!($n, $shape, justdim);
+        let array = match $array.is_standard_layout() {
+            true => {
+                match $array.to_shape(dim) {
+                    Ok(val) => val.into_shared(),
+                    Err(err) => {
+                        core::panic!("Shape should be compatible shape={dim:?}: {err:?}");
+                    }
+                }
+            },
+            false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
+        };
+        array.into_dyn()
+    }};
+    (
+        ty $ty:ty,
+        shape $shape:expr,
+        array $array:expr,
+        d $D:expr
+    ) => {{
+        match $D {
+            1 => reshape!(ty $ty, n 1, shape $shape, array $array),
+            2 => reshape!(ty $ty, n 2, shape $shape, array $array),
+            3 => reshape!(ty $ty, n 3, shape $shape, array $array),
+            4 => reshape!(ty $ty, n 4, shape $shape, array $array),
+            5 => reshape!(ty $ty, n 5, shape $shape, array $array),
+            6 => reshape!(ty $ty, n 6, shape $shape, array $array),
+            _ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
+        }
+    }};
+}
+
+/// Slice a tensor
+#[macro_export]
+macro_rules! slice {
+    ($tensor:expr, $slices:expr) => {
+        slice!($tensor, $slices, F64, F32, I64, I32, I16, I8, U64, U32, U16, U8, Bool)
+    };
+    ($tensor:expr, $slices:expr, $($variant:ident),*) => {
+        match $tensor {
+            $(NdArrayTensor::$variant(s) => { NdArrayOps::slice(s.view(), $slices).into() })*
+        }
+    };
+}
+
+impl NdArrayTensor {
+    /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
+    ///
+    /// This method attempts zero-copy loading when possible. If the data has properly
+    /// aligned bytes that can be borrowed, it creates a borrowed tensor. Otherwise,
+    /// it falls back to copying the data.
+    ///
+    /// Zero-copy loading works when:
+    /// - The data's bytes are properly aligned for the element type
+    /// - The bytes can be borrowed (e.g., from mmap'd file or static data)
+    pub fn from_data(data: TensorData) -> NdArrayTensor {
+        // Only use Borrowed storage for non-native allocations (e.g., burnpack mmap/file).
+        // For native Rust heap allocations (the common case), go directly to owned storage:
+        // `from_data_owned` reclaims the Vec zero-copy via `into_vec`, while
+        // Borrowed storage would trigger a full memcopy on every single operation.
+        if data.bytes.property() != AllocationProperty::Native {
+            match Self::try_from_data_borrowed(data) {
+                Ok(tensor) => return tensor,
+                Err(data) => return Self::from_data_owned(data),
+            }
+        }
+        Self::from_data_owned(data)
+    }
+
+    /// Try to create a tensor with borrowed storage (zero-copy).
+    ///
+    /// Takes ownership of TensorData and returns it back on failure.
+    /// No cloning occurs - bytes are moved into storage or returned on failure.
+    ///
+    /// Returns `Err(data)` if borrowing is not possible (e.g., misaligned data).
+    fn try_from_data_borrowed(data: TensorData) -> Result<Self, TensorData> {
+        let TensorData {
+            bytes,
+            shape,
+            dtype,
+        } = data;
+
+        macro_rules! try_borrow {
+            ($ty:ty, $variant:ident, $bytes:expr, $shape:expr) => {
+                match NdArrayStorage::<$ty>::from_borrowed($bytes, $shape) {
+                    Ok(storage) => return Ok(NdArrayTensor::$variant(storage)),
+                    Err((bytes, shape)) => (bytes, shape),
+                }
+            };
+        }
+
+        // Try to create borrowed storage; get bytes back on failure
+        let (bytes, shape) = match dtype {
+            DType::F64 => try_borrow!(f64, F64, bytes, shape),
+            DType::F32 => try_borrow!(f32, F32, bytes, shape),
+            DType::I64 => try_borrow!(i64, I64, bytes, shape),
+            DType::I32 => try_borrow!(i32, I32, bytes, shape),
+            DType::I16 => try_borrow!(i16, I16, bytes, shape),
+            DType::I8 => try_borrow!(i8, I8, bytes, shape),
+            DType::U64 => try_borrow!(u64, U64, bytes, shape),
+            DType::U32 => try_borrow!(u32, U32, bytes, shape),
+            DType::U16 => try_borrow!(u16, U16, bytes, shape),
+            DType::U8 => try_borrow!(u8, U8, bytes, shape),
+            DType::Bool(BoolStore::Native) => try_borrow!(bool, Bool, bytes, shape),
+            _ => (bytes, shape), // QFloat not supported for zero-copy
+        };
+
+        Err(TensorData {
+            bytes,
+            shape,
+            dtype,
+        })
+    }
+
+    /// Create a tensor with owned storage.
+    ///
+    /// This may or may not copy data depending on whether the underlying bytes
+    /// can be reclaimed (via `try_into_vec`). If bytes are uniquely owned,
+    /// no copy occurs; otherwise data is copied to a new allocation.
+    fn from_data_owned(data: TensorData) -> NdArrayTensor {
+        let shape = data.shape.to_vec(); // TODO: into_vec
+
+        macro_rules! execute {
+            ($data: expr, [$($dtype: pat => $ty: ty),*]) => {
+                match $data.dtype {
+                    $( $dtype => {
+                        match data.into_vec::<$ty>() {
+                            Ok(vec) => unsafe { ArrayD::from_shape_vec_unchecked(shape, vec) }.into_shared(),
+                            Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
+                        }.into()
+                    }, )*
+                    other => unimplemented!("Unsupported dtype {other:?}"),
+                }
+            };
+        }
+
+        execute!(data, [
+            DType::F64 => f64, DType::F32 => f32,
+            DType::I64 => i64, DType::I32 => i32, DType::I16 => i16, DType::I8 => i8,
+            DType::U64 => u64, DType::U32 => u32, DType::U16 => u16, DType::U8 => u8,
+            DType::Bool(BoolStore::Native) => bool
+        ])
+    }
+}
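End to end, the allocation property of the incoming bytes picks the storage kind; a sketch of the two paths (mirroring the integration tests further down):

    // Native heap data: the Vec is reclaimed into owned storage.
    let owned = NdArrayTensor::from_data(TensorData::from([1.0f32, 2.0, 3.0, 4.0]));
    assert!(!owned.is_borrowed());

    // Non-native data (e.g. mmap'd burnpack bytes): stays borrowed until mutated.
    let elems = Bytes::from_elems(vec![1.0f32, 2.0, 3.0, 4.0]);
    let non_native = Bytes::from_shared(
        bytes::Bytes::copy_from_slice(&elems),
        AllocationProperty::Other,
    );
    let data = TensorData::from_bytes(non_native, Shape::new([2, 2]), DType::F32);
    assert!(NdArrayTensor::from_data(data).is_borrowed());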
+
+/// A quantized tensor for the ndarray backend.
+#[derive(Clone, Debug)]
+pub struct NdArrayQTensor {
+    /// The quantized tensor.
+    pub qtensor: NdArrayTensor,
+    /// The quantization scheme.
+    pub scheme: QuantScheme,
+    /// The quantization parameters.
+    pub qparams: Vec<QParams<f32>>,
+}
+
+impl NdArrayQTensor {
+    /// Returns the quantization strategy, including quantization parameters, for the given tensor.
+    pub fn strategy(&self) -> QuantizationStrategy {
+        match self.scheme {
+            QuantScheme {
+                level: QuantLevel::Tensor,
+                mode: QuantMode::Symmetric,
+                value:
+                    QuantValue::Q8F
+                    | QuantValue::Q8S
+                    | QuantValue::E4M3
+                    | QuantValue::E5M2
+                    | QuantValue::Q4F
+                    | QuantValue::Q4S
+                    | QuantValue::E2M1
+                    | QuantValue::Q2F
+                    | QuantValue::Q2S,
+                ..
+            } => QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
+                self.qparams[0].scales,
+                self.scheme.value,
+            )),
+            QuantScheme {
+                level: QuantLevel::Block(block_size),
+                mode: QuantMode::Symmetric,
+                value:
+                    QuantValue::Q8F
+                    | QuantValue::Q8S
+                    | QuantValue::E4M3
+                    | QuantValue::E5M2
+                    | QuantValue::Q4F
+                    | QuantValue::Q4S
+                    | QuantValue::E2M1
+                    | QuantValue::Q2F
+                    | QuantValue::Q2S,
+                ..
+            } => QuantizationStrategy::PerBlockSymmetric(
+                self.qparams
+                    .iter()
+                    .map(|q| SymmetricQuantization::init(q.scales, self.scheme.value))
+                    .collect(),
+                block_size,
+            ),
+        }
+    }
+}
+
+impl QTensorPrimitive for NdArrayQTensor {
+    fn scheme(&self) -> &QuantScheme {
+        &self.scheme
+    }
+
+    fn default_scheme() -> QuantScheme {
+        QuantScheme::default().with_store(burn_backend::quantization::QuantStore::Native)
+    }
+}
+
+impl TensorMetadata for NdArrayQTensor {
+    fn dtype(&self) -> DType {
+        DType::QFloat(self.scheme)
+    }
+
+    fn shape(&self) -> Shape {
+        self.qtensor.shape()
+    }
+
+    fn rank(&self) -> usize {
+        self.shape().num_dims()
+    }
+}
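+
+// Worked example of the mapping above (values illustrative; `QParams` is the
+// inferred parameter type): a per-tensor symmetric Q8S scheme with
+// `qparams = vec![QParams { scales: 0.01 }]` yields
+// `PerTensorSymmetric(SymmetricQuantization::init(0.01, QuantValue::Q8S))`,
+// while a `QuantLevel::Block(32)` scheme yields one `SymmetricQuantization`
+// per block scale.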
+
+#[cfg(test)]
+mod tests {
+    use crate::NdArray;
+    use alloc::vec;
+
+    use super::*;
+    use burn_backend::{
+        Distribution,
+        ops::{FloatTensorOps, QTensorOps},
+        quantization::{QuantStore, QuantizationParametersPrimitive},
+    };
+    use burn_std::rand::get_seeded_rng;
+
+    #[test]
+    fn should_support_into_and_from_data_1d() {
+        let data_expected = TensorData::random::<f32, _, _>(
+            Shape::new([3]),
+            Distribution::Default,
+            &mut get_seeded_rng(),
+        );
+        let tensor = NdArrayTensor::from_data(data_expected.clone());
+
+        let data_actual = tensor.into_data();
+
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn should_support_into_and_from_data_2d() {
+        let data_expected = TensorData::random::<f32, _, _>(
+            Shape::new([2, 3]),
+            Distribution::Default,
+            &mut get_seeded_rng(),
+        );
+        let tensor = NdArrayTensor::from_data(data_expected.clone());
+
+        let data_actual = tensor.into_data();
+
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn should_support_into_and_from_data_3d() {
+        let data_expected = TensorData::random::<f32, _, _>(
+            Shape::new([2, 3, 4]),
+            Distribution::Default,
+            &mut get_seeded_rng(),
+        );
+        let tensor = NdArrayTensor::from_data(data_expected.clone());
+
+        let data_actual = tensor.into_data();
+
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn should_support_into_and_from_data_4d() {
+        let data_expected = TensorData::random::<f32, _, _>(
+            Shape::new([2, 3, 4, 2]),
+            Distribution::Default,
+            &mut get_seeded_rng(),
+        );
+        let tensor = NdArrayTensor::from_data(data_expected.clone());
+
+        let data_actual = tensor.into_data();
+
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn should_support_qtensor_strategy() {
+        type B = NdArray;
+        let scale: f32 = 0.009_019_608;
+        let device = Default::default();
+
+        let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device);
+        let scheme = QuantScheme::default()
+            .with_value(QuantValue::Q8S)
+            .with_store(QuantStore::Native);
+        let qparams = QuantizationParametersPrimitive {
+            scales: B::float_from_data(TensorData::from([scale]), &device),
+        };
+        let qtensor: NdArrayQTensor = B::quantize(tensor, &scheme, qparams);
+
+        assert_eq!(qtensor.scheme(), &scheme);
+        assert_eq!(
+            qtensor.strategy(),
+            QuantizationStrategy::PerTensorSymmetric(SymmetricQuantization::init(
+                scale,
+                QuantValue::Q8S
+            ))
+        );
+    }
+
+    // ==========================================================================
+    // Zero-copy integration tests
+    // These tests verify end-to-end zero-copy behavior through NdArrayTensor.
+    // ==========================================================================
+
+    #[test]
+    fn zero_copy_creates_borrowed_storage_for_non_native() {
+        // Verify that from_data creates borrowed storage for non-native allocations
+        // (e.g. burnpack mmap/file data tagged with AllocationProperty::Other or File).
+        // Native heap allocations intentionally use Owned storage for performance.
+        use burn_backend::AllocationProperty;
+        use burn_std::Bytes;
+
+        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
+        let bytes = Bytes::from_elems(data);
+        // Tag as Other to simulate burnpack / mmap data (non-native backing storage).
+        let non_native_bytes = Bytes::from_shared(
+            bytes::Bytes::copy_from_slice(&bytes),
+            AllocationProperty::Other,
+        );
+        let tensor_data = TensorData::from_bytes(non_native_bytes, Shape::new([2, 2]), DType::F32);
+
+        let tensor = NdArrayTensor::from_data(tensor_data);
+
+        match &tensor {
+            NdArrayTensor::F32(storage) => {
+                assert!(
+                    storage.is_borrowed(),
+                    "ZERO-COPY REGRESSION: from_data should create borrowed storage \
+                     for non-native (e.g. burnpack) TensorData"
+                );
+                assert!(
+                    !storage.is_unique(),
+                    "ZERO-COPY REGRESSION: borrowed storage must report is_unique() == false"
+                );
+            }
+            _ => panic!("Expected F32 tensor"),
+        }
+    }
+
+    #[test]
+    fn native_alloc_creates_owned_storage() {
+        // Native heap allocations must use Owned storage to avoid the memcpy.
+        use burn_std::Bytes;
+
+        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
+        let bytes = Bytes::from_elems(data); // AllocationProperty::Native
+        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);
+
+        let tensor = NdArrayTensor::from_data(tensor_data);
+
+        match &tensor {
+            NdArrayTensor::F32(storage) => {
+                assert!(
+                    !storage.is_borrowed(),
+                    "PERF REGRESSION: from_data must NOT create borrowed storage \
+                     for native TensorData"
+                );
+            }
+            _ => panic!("Expected F32 tensor"),
+        }
+    }
+
+    #[test]
+    fn zero_copy_data_integrity() {
+        // Verify data is correctly accessible through the resulting storage.
+        use burn_std::Bytes;
+
+        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
+        let bytes = Bytes::from_elems(data);
+        let tensor_data = TensorData::from_bytes(bytes, Shape::new([2, 2]), DType::F32);
+
+        let tensor = NdArrayTensor::from_data(tensor_data);
+
+        match &tensor {
+            NdArrayTensor::F32(storage) => {
+                let view = storage.view();
+                assert_eq!(view[[0, 0]], 1.0);
+                assert_eq!(view[[0, 1]], 2.0);
+                assert_eq!(view[[1, 0]], 3.0);
+                assert_eq!(view[[1, 1]], 4.0);
+            }
+            _ => panic!("Expected F32 tensor"),
+        }
+    }
+
+    #[test]
+    fn zero_copy_fallback_when_bytes_owned() {
+        // When TensorData owns its bytes exclusively, it may use the copy path.
+        // This is expected behavior; verify it still works correctly.
+        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
+        let tensor = NdArrayTensor::from_data(data.clone());
+        let result = tensor.into_data();
+
+        assert_eq!(data, result, "Data should round-trip correctly");
+    }
+}
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
new file mode 100644
index 00000000..76a06e6b
--- /dev/null
+++ b/rust-toolchain.toml
@@ -0,0 +1,2 @@
+[toolchain]
+channel = "1.94.0"