feat(simd): BF16x16 + F16x16 SIMD vectors + slice ops (sprint W3-A) #126
Conversation
Closes parity items 2 + 3. Scalar dispatch (upcast f32 → op → downcast). SIMD-accelerated paths (AVX2 emulation, AVX-512-BF16 native, NEON +fp16) are a follow-up. The scalar implementation is correct and portable, and unblocks burn's NdArrayElement bound for half types.

- src/simd_half.rs: new 691-LOC module
- src/lib.rs: pub mod simd_half declaration
- src/simd.rs: re-exports

21 new tests, all passing. Total lib tests: 1817+ pass.

https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj
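The scalar dispatch pattern described above can be sketched as follows. The function name `add_bf16_inplace` comes from this PR; the bit-level conversions are an assumption for illustration (a real implementation would likely use the `half` crate and round-to-nearest-even rather than truncating):

```rust
// bf16 is the top 16 bits of an f32, so the upcast is a left shift
// and the downcast truncates the low mantissa bits.
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

fn f32_to_bf16(x: f32) -> u16 {
    // Truncating downcast; production code would round-to-nearest-even.
    (x.to_bits() >> 16) as u16
}

// Scalar dispatch: upcast each lane to f32, apply the op, downcast back.
fn add_bf16_inplace(dst: &mut [u16], src: &[u16]) {
    assert_eq!(dst.len(), src.len());
    for (d, s) in dst.iter_mut().zip(src) {
        *d = f32_to_bf16(bf16_to_f32(*d) + bf16_to_f32(*s));
    }
}
```

This is the portable fallback the PR ships; the follow-up SIMD paths would replace the per-lane loop with vectorized converts and adds.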
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3358057a9c
```rust
#[cfg(all(feature = "std", not(all(target_arch = "x86_64", target_feature = "avx512bf16"))))]
pub use crate::simd_half::BF16x16 as BF16x16;
```
Keep BF16x16 API stable across target features
crate::simd::BF16x16 now resolves to two incompatible types depending on compile flags: this line hides the new portable simd_half::BF16x16 whenever target_feature="avx512bf16" is set, so AVX-512-BF16 builds get simd_avx512::BF16x16 (an unsafe, load/convert-only API) instead of the new arithmetic API (from_slice, add, mul, copy_to_slice). Any consumer code written against the newly introduced BF16x16 methods will therefore compile on scalar/NEON/AVX2 targets but fail to compile on AVX-512-BF16 targets, which breaks the cross-target SIMD dispatch parity this change is meant to provide.
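To make the compatibility concern concrete, here is a hedged sketch of the portable API surface the review says every target should expose. The method names (from_slice, add, mul, copy_to_slice) are taken from the review comment; the struct layout and lane arithmetic are assumptions:

```rust
// A portable 16-lane bf16 vector; lanes hold raw bf16 bit patterns.
#[derive(Clone, Copy)]
pub struct BF16x16 {
    lanes: [u16; 16],
}

impl BF16x16 {
    pub fn from_slice(s: &[u16]) -> Self {
        let mut lanes = [0u16; 16];
        lanes.copy_from_slice(&s[..16]);
        Self { lanes }
    }

    // Per-lane: upcast to f32, add, truncating downcast back to bf16.
    pub fn add(self, rhs: Self) -> Self {
        let mut out = [0u16; 16];
        for i in 0..16 {
            let a = f32::from_bits((self.lanes[i] as u32) << 16);
            let b = f32::from_bits((rhs.lanes[i] as u32) << 16);
            out[i] = ((a + b).to_bits() >> 16) as u16;
        }
        Self { lanes: out }
    }

    // Per-lane multiply with the same upcast/downcast round trip.
    pub fn mul(self, rhs: Self) -> Self {
        let mut out = [0u16; 16];
        for i in 0..16 {
            let a = f32::from_bits((self.lanes[i] as u32) << 16);
            let b = f32::from_bits((rhs.lanes[i] as u32) << 16);
            out[i] = ((a * b).to_bits() >> 16) as u16;
        }
        Self { lanes: out }
    }

    pub fn copy_to_slice(self, dst: &mut [u16]) {
        dst[..16].copy_from_slice(&self.lanes);
    }
}
```

If the AVX-512-BF16 `BF16x16` offered this same method set (with a vectorized body), the cfg-gated re-export above would be transparent to consumers regardless of compile flags.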
Closes parity items (2)+(3): half-precision SIMD vector types so burn's NdArrayElement::F16/BF16 enum variants can dispatch through ndarray's SIMD layer.

What ships:
- src/simd_half.rs (691 LOC): BF16x16 and F16x16 types, scalar dispatch (upcast f32 → op → downcast)
- add_bf16_inplace, mul_bf16_inplace, add_f16_inplace, mul_f16_inplace, cast_*_to_*_batch (8 helpers)
- src/simd.rs re-exports

Tests: 21 new, all passing. Total lib: 1817+ pass.
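One of the batch cast helpers listed above can be sketched as follows. The concrete name `cast_bf16_to_f32_batch` matches the `cast_*_to_*_batch` pattern but is an assumption, as is the body; it illustrates why the bf16 upcast is cheap:

```rust
// Hypothetical batch cast: bf16 shares f32's sign and exponent layout,
// so upcasting is a single left shift per element.
fn cast_bf16_to_f32_batch(src: &[u16], dst: &mut [f32]) {
    assert_eq!(src.len(), dst.len());
    for (d, &s) in dst.iter_mut().zip(src) {
        *d = f32::from_bits((s as u32) << 16);
    }
}
```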
SIMD-accelerated paths (AVX2 emulation, AVX-512-BF16 native, NEON +fp16) are a follow-up. The scalar implementation is correct and portable, and unblocks burn's NdArrayElement bound for half types.

https://claude.ai/code/session_01NYGrxVopyszZYgLBxe4hgj