diff --git a/.claude/board/AGENT_LOG.md b/.claude/board/AGENT_LOG.md index 27ac38ac..b6948deb 100644 --- a/.claude/board/AGENT_LOG.md +++ b/.claude/board/AGENT_LOG.md @@ -1178,3 +1178,349 @@ SIMD savings disappear below GPU baseline. integration candidate. The performance levers are GPU shader optimization + wgpu buffer bandwidth — outside ndarray's scope. + + +# ═══════════════════════════════════════════════════════════════════ +# Round 3-portable-simd — full 30-type coverage for crate::simd_nightly +# ═══════════════════════════════════════════════════════════════════ + +> **Branch:** `claude/portable-simd-nightly` +> **Goal:** expand `src/simd_nightly/` from 5-type draft (F32x16, F64x8, +> U8x64, U32x16, F32Mask16) to full 30-type coverage that mirrors the +> AVX-512 / AVX2 polyfill surface. Miri-runnable backend wrapping +> `core::simd::*`. +> **Fleet:** 12 Sonnet workers + 1 Sonnet meta. Same A2A pattern +> (`tee -a` to this file). +> **Permission:** the `.claude/settings.local.json` allow-list set up +> in round-2 still covers `tee -a /home/user/ndarray/.claude/board/AGENT_LOG.md`. + +## Fleet manifest (round 3-portable-simd) + +| # | Agent | Scope (file) | Types | +|---|---|---|---| +| 1 | f32-wrap | `src/simd_nightly/f32_types.rs` | F32x16, F32x8 | +| 2 | f64-wrap | `src/simd_nightly/f64_types.rs` | F64x8, F64x4 | +| 3 | u8-wrap | `src/simd_nightly/u8_types.rs` | U8x32, U8x64 | +| 4 | u-word-wrap | `src/simd_nightly/u_word_types.rs` | U16x32, U32x16, U64x8 | +| 5 | i8-wrap | `src/simd_nightly/i8_types.rs` | I8x32, I8x64 | +| 6 | i-word-wrap | `src/simd_nightly/i_word_types.rs` | I16x16, I16x32, I32x16, I64x8 | +| 7 | masks-wrap | `src/simd_nightly/masks.rs` | F32Mask16, F64Mask8 | +| 8 | bf16-emul | `src/simd_nightly/bf16_types.rs` | BF16x16, BF16x8 (scalar emulation — no `core::simd` half-prec) | +| 9 | f16-emul | `src/simd_nightly/f16_types.rs` | F16x16 (scalar emulation) | +| 10 | ops-macros | `src/simd_nightly/ops.rs` | Add/Sub/Mul/Div/BitAnd/BitOr/BitXor/Default macros applied to all types | +| 11 | exotic-fallbacks | `src/simd_nightly/exotic_methods.rs` | permute_bytes, shuffle_bytes scalar fallbacks for U8x32/U8x64 (`core::simd::swizzle` is const N — can't accept runtime idx vector) | +| 12 | parity-tests | `src/simd_nightly/tests.rs` | Comprehensive parity tests vs simd_avx512 / simd_avx2 references where they exist | +| M | meta-r3 | synthesis | Sonnet | + +## Round-3-portable-simd entries (newest first) + + +## 2026-05-13 — agent #9 f16-emul (sonnet-4-6) + +**File:** `src/simd_nightly/f16_types.rs` (220 lines) +**Status:** DONE + +- Replaced stub with full `F16x16([u16; 16])` scalar emulation. +- `LANES = 16`; constructors: `splat(f32)`, `from_slice(&[u16])`, `from_array`, `to_array`, `copy_to_slice`. +- Conversions: `to_f32_array`, `from_f32_array`. +- IEEE-754 binary16 logic copied verbatim from `src/hpc/quantized.rs` F16 methods (lines 193-301); cited in doc comments. +- `cargo check --features nightly-simd`: zero errors in `f16_types.rs`; 58 pre-existing errors in other simd_nightly files (masks.rs, ops.rs, etc.). + +## 2026-05-13T00:00 — agent #8 bf16-emul (sonnet) + +**File:** `src/simd_nightly/bf16_types.rs` (248 lines) +**Verdict:** PASS + +**Summary:** +- Implemented `BF16x16` and `BF16x8` as `#[repr(transparent)]` wrappers over `[u16; N]`. +- Methods: `splat(f32)`, `from_slice(&[u16])`, `from_array`, `to_array`, `copy_to_slice`, `to_f32_lossy() -> [f32; N]`, `from_f32_truncate([f32; N]) -> Self`, `LANES: usize`. 
+- Conversion helpers `f32_to_bf16_bits` (>> 16) and `bf16_bits_to_f32` (<< 16) are pure safe Rust.
+- 12 unit tests cover splat roundtrip, truncate/expand, slice/array roundtrip, LANES const, and known bit patterns (1.0 = 0x3F80, -1.0 = 0xBF80).
+- `rustup run nightly cargo check --features nightly-simd -p ndarray --lib`: zero errors in bf16_types.rs (pre-existing errors in other stub files owned by other agents).

## 2026-05-13 — agent #6 i-word-wrap (sonnet-4-6)

**File:** `src/simd_nightly/i_word_types.rs` (449 lines)
**Status:** DONE — `cargo check --features nightly-simd` passes clean

**Work done:**
- Replaced stub with full implementations of `I16x16`, `I16x32`, `I32x16`, `I64x8`
- Each type: `LANES`, `splat`, `from_slice`, `from_array`, `to_array`, `copy_to_slice`
- Reductions: `reduce_sum` (wrapping), `reduce_min`, `reduce_max` via `SimdInt`
- Lane-wise: `simd_min`/`simd_max` via `SimdOrd` (added to imports alongside `SimdPartialOrd`)
- Compare→mask: `cmpeq_mask`/`cmpgt_mask` — `to_bitmask() as uN` (N = lane count: u16/u32/u16/u8)
- Saturating: `saturating_add`/`saturating_sub` on I16x16 and I16x32 only (I32/I64 have no sat ops in AVX-512 reference)
- `PartialEq` + `Display` impls; operator impls deferred to agent #10

## 2026-05-13T21:30 — agent #7 masks-wrap (sonnet) [backfilled by main]

**File:** `src/simd_nightly/masks.rs` (196 lines)
**Status:** COMPILES (zero errors in this file)

Implemented 4 mask wrapper structs:
- `F32Mask16(Mask<i32, 16>)` — mirrors `simd_avx512::F32Mask16`
- `F32Mask8(Mask<i32, 8>)` — for agents #1/#2 F32x8 cmp return
- `F64Mask8(Mask<i64, 8>)` — mirrors `simd_avx512::F64Mask8`
- `F64Mask4(Mask<i64, 4>)` — for agents #1/#2 F64x4 cmp return

Per-struct methods: `to_bitmask() → uN` (with cast from u64),
`from_bitmask(bits: uN) → Self`, `select(true, false) → FloatType`,
`all() → bool`, `any() → bool`.

**Key nightly-API finding:** `core::simd::Mask::to_bitmask()` ALWAYS
returns `u64` regardless of lane count; `from_bitmask()` ALWAYS takes
`u64`. The wrappers cast (`as u8` / `as u16` for narrower returns,
`bits as u64` for widening). The `select` method requires
`use core::simd::prelude::Select` in scope.

`mod.rs` line 43 updated to expose all 4: `pub use masks::{F32Mask16,
F32Mask8, F64Mask8, F64Mask4};`.


## 2026-05-13T00:00 — agent #2 f64-wrap (sonnet)

**File:** `src/simd_nightly/f64_types.rs` (307 lines)
**Verdict:** DONE

**Types delivered:** `F64x8` (8-lane f64) and `F64x4` (4-lane f64).

**Full API per type:**
- Constructors: `splat`, `from_slice`, `from_array`, `to_array`, `copy_to_slice`
- Reductions: `reduce_sum`, `reduce_min`, `reduce_max`
- Lane-wise: `simd_min`, `simd_max`, `simd_clamp`
- FMA + math: `mul_add`, `sqrt`, `round`, `floor`, `abs`
- Bits: `to_bits` → `U64x8` (F64x8) / `U64x4` (F64x4)
- Comparisons: `simd_eq/ne/lt/le/gt/ge` → `F64Mask8` / `F64Mask4`
- `LANES: usize` const

**Key decisions:**
- `std::simd::StdFloat` required (not `core::simd::num::SimdFloat`) for `mul_add/sqrt/round/floor` — `core::simd::num::SimdFloat` only covers `reduce_*` and `simd_min/max`; StdFloat provides the FP math methods.
- Added `U64x4` and `U32x8` to `u_word_types.rs` as `F64x4::to_bits` and `F32x8::to_bits` companion types (agent #4 scope, but stubs were empty; noted in file header).
- Operator impls delegated to agent #10's `ops.rs` (already wired: `impl_fp_ops!(F64x8)` + `impl_fp_ops!(F64x4)`).

**Cargo check:** `rustup run nightly cargo check --features nightly-simd -p ndarray --lib` → `Finished` (0 errors).
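
A minimal sketch of the `SimdFloat` vs `StdFloat` split that agents #1 and #2 both hit (illustrative only; trait paths as reported above, subject to nightly churn; the function name is hypothetical):

```rust
#![feature(portable_simd)]
use core::simd::num::SimdFloat; // reductions (reduce_sum/min/max) and abs live here
use std::simd::StdFloat;        // FP math (mul_add/sqrt/floor/round) is std-only
use std::simd::f64x8;

fn fma_then_sum(a: f64x8, b: f64x8, c: f64x8) -> f64 {
    // `mul_add` resolves via StdFloat; `reduce_sum` via SimdFloat.
    a.mul_add(b, c).reduce_sum()
}
```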

## 2026-05-13T00:20 — agent #1 f32-wrap (sonnet)

**File:** `src/simd_nightly/f32_types.rs` (395 lines)
**Types:** F32x16 (16 methods), F32x8 (16 methods)
**Status:** COMPILES

**Notes / TODOs:**
- Both F32x16 and F32x8 implement: LANES const, splat, from_slice, from_array, to_array, copy_to_slice, reduce_sum, reduce_min, reduce_max, simd_min, simd_max, simd_clamp, mul_add, sqrt, round, floor, abs, to_bits, from_bits, simd_eq, simd_ne, simd_lt, simd_le, simd_gt, simd_ge.
- Key fix: `mul_add`, `sqrt`, `round`, `floor` require `std::simd::StdFloat` (NOT `core::simd::num::SimdFloat`).
- Also added `U32x8` struct to `u_word_types.rs` (required by F32x8::to_bits/from_bits); updated `mod.rs` to export `U32x8` and `U64x4`.
- `#![feature(portable_simd)]` must be enabled at crate root (lib.rs) for `std::simd::StdFloat` to exist; already present via nightly-simd feature.
- masks.rs (agent #7) and u_word_types.rs (agent #4) were already populated when this agent ran — no circular deps.

## 2026-05-13 — agent #3 u8-wrap (sonnet-4.6)

**File:** `src/simd_nightly/u8_types.rs` (~830 lines)
**Status:** DONE — `cargo check --features nightly-simd` passes (0 errors from this file)

**Implemented:**
- `pub struct U8x64(pub core::simd::u8x64)` + `pub struct U8x32(pub core::simd::u8x32)`
- Both: `LANES` const, `splat`, `from_slice`, `from_array`, `to_array`, `copy_to_slice`
- Both: `reduce_sum` (wrapping), `reduce_min`, `reduce_max`, `sum_bytes_u64` (u16 promotion)
- Both: `simd_min`, `simd_max` (required `SimdOrd` import in addition to `SimdPartialOrd`)
- Both: `saturating_add`, `saturating_sub`
- Both: `pairwise_avg` via `cast::<u16>()` promotion (no native avg in `core::simd`)
- Both: `cmpeq_mask`, `cmpgt_mask`, `movemask` — U8x64 → `u64`, U8x32 → `u32` (cast from `u64` since `to_bitmask()` always returns `u64`)
- Both: `shr_epi16`, `shl_epi16` via `transmute` to `[u16; N]` scalar loop
- Both: `nibble_popcount_lut()` as `from_array` with replicated 0,1,1,2,… pattern
- Both: `Default` → `splat(0)`
- 26 unit tests covering all methods

**Decisions:** `nibble_popcount_lut` kept here (pure `from_array`, no shuffle dependency). `permute_bytes`, `shuffle_bytes`, `mask_blend`, `unpack_lo/hi_epi8` deferred to agent #11 (`exotic_methods.rs`) per spec.

**Key finding:** `core::simd::Mask::to_bitmask()` returns `u64` for ALL lane widths including 32-lane vectors; U8x32 masks cast `as u32` to match AVX2 shape.

## 2026-05-13T21:45 — agent #5 i8-wrap (sonnet) [backfilled by main]

**File:** `src/simd_nightly/i8_types.rs` (263 lines)
**Status:** COMPILES (zero errors in this file)

Implemented `I8x64(pub i8x64)` and `I8x32(pub i8x32)` — both
`#[repr(transparent)]`, `Copy + Clone + Debug + PartialEq`.

Surface mirrors `simd_avx512.rs::I8x64` / `::I8x32`:
- Constructors: splat, from_slice, from_array, to_array, copy_to_slice
- Reductions: reduce_sum (wrapping), reduce_min, reduce_max
- Lane-wise: simd_min, simd_max
- Compare → mask: cmpeq_mask (u64 for I8x64, u32 for I8x32), cmpgt_mask
  (native signed via `simd_gt`)
- Saturating: saturating_add, saturating_sub

**Deviation from spec header:** added `SimdOrd` to imports alongside
`SimdPartialEq` / `SimdPartialOrd` — needed for `simd_min` / `simd_max`
to resolve on integer types in current nightly.
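
The recurring `SimdOrd` requirement, sketched (illustrative only; trait path per the reports above, and the function name is hypothetical):

```rust
#![feature(portable_simd)]
use core::simd::cmp::SimdOrd; // simd_min / simd_max / simd_clamp on integer vectors
use core::simd::i8x32;

fn clamp_lanes(v: i8x32, lo: i8x32, hi: i8x32) -> i8x32 {
    // Without `SimdOrd` in scope, these calls fail to resolve on integer
    // vectors; `SimdPartialOrd` alone only provides the `simd_lt` family.
    v.simd_max(lo).simd_min(hi)
}
```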
+ +## 2026-05-13T21:50 — agent #6 i-word-wrap (sonnet) [backfilled by main] + +**File:** `src/simd_nightly/i_word_types.rs` (449 lines) +**Status:** COMPILES (zero errors in this file) + +Implemented 4 wrappers: `I16x16`, `I16x32`, `I32x16`, `I64x8`. Each +`#[repr(transparent)]`, `Copy + Clone + Debug + PartialEq + Display`. + +Per-type surface: splat, from_slice, from_array, to_array, +copy_to_slice, reduce_sum (wrap), reduce_min, reduce_max, simd_min, +simd_max, cmpeq_mask, cmpgt_mask. + +`saturating_add` / `saturating_sub` added for I16 (matches AVX-512 +reference which provides them for i16 but not i32/i64). + +**Same SimdOrd finding as agent #5.** Also: bitmask cast `to_bitmask() +→ u64 as uN` for narrower mask shapes (u16 for 16-lane, u32 for 32-lane, +u8 for 8-lane). + + +## 2026-05-13T22:05 — agent #1 f32-wrap (sonnet) [backfilled by main] + +**File:** `src/simd_nightly/f32_types.rs` (395 lines) +**Status:** COMPILES (zero errors in this file) + +`F32x16(pub core::simd::f32x16)` + `F32x8(pub core::simd::f32x8)` with +full 16-method API per `simd_avx512.rs`: LANES, splat, from_slice, +from_array, to_array, copy_to_slice, reduce_sum/min/max, simd_min/max/ +clamp, mul_add, sqrt, round, floor, abs, to_bits (via +`super::u_word_types::{U32x16,U32x8}`), from_bits, simd_eq/ne/lt/le/gt/ +ge → `super::masks::{F32Mask16, F32Mask8}`. + +**Key nightly-API finding (echoed by agent #2 independently):** +`mul_add` / `sqrt` / `round` / `floor` require `use std::simd::StdFloat`, +NOT `core::simd::num::SimdFloat`. SimdFloat provides reduce/min/max/ +clamp but not the transcendentals. Worth folding into the +fleet-handover doc. + +Side effect: added `U32x8` to u_word_types.rs (agent #4 scope) + +re-exported from mod.rs. Necessary for F32x8::to_bits. + +Agent reports `cargo +nightly check --features nightly-simd` passes +crate-wide with zero errors at the moment of completion. Pending +remaining 3 agents. + +## 2026-05-13T22:08 — agent #2 f64-wrap (sonnet) [backfilled by main] + +**File:** `src/simd_nightly/f64_types.rs` (307 lines) +**Status:** COMPILES + +`F64x8(pub core::simd::f64x8)` + `F64x4(pub core::simd::f64x4)`. Same +shape as agent #1 at half width. Same `StdFloat` import requirement. + +Side effect: added `U64x4` + `U32x8` to u_word_types.rs (agent #4 +scope) for `F64x4::to_bits` and `F32x8::to_bits`. + + +## 2026-05-13T22:15 — agent #3 u8-wrap (sonnet) [backfilled by main] + +**File:** `src/simd_nightly/u8_types.rs` (~830 lines) +**Status:** COMPILES (zero errors in this file) + +`U8x64(pub core::simd::u8x64)` + `U8x32(pub core::simd::u8x32)` with +full method parity against `simd_avx512::U8x64` + `simd_avx2::U8x32` +(PR #144). 

Surface per type:
- Constructors: splat, from_slice, from_array, to_array, copy_to_slice
- Reductions: reduce_sum (wraps), reduce_min, reduce_max,
  `sum_bytes_u64` (promotes to u16×N to avoid wrap)
- Lane-wise: simd_min, simd_max
- Saturating: saturating_add, saturating_sub
- Avg: `pairwise_avg` — promotes to u16, computes `(a+b+1)>>1`, casts
  back to u8 (`core::simd` has no native `_mm512_avg_epu8` equivalent)
- Compare → mask: cmpeq_mask, cmpgt_mask, movemask
  - U8x64 returns `u64`, U8x32 returns `u32`
  - Cast from `u64` since `to_bitmask()` always returns u64 (per agents
    #5, #6, #7 findings)
- Shifts: shr_epi16, shl_epi16 — reinterpret via `transmute` to
  `[u16; N]`, scalar shift loop, transmute back
- `nibble_popcount_lut()` — kept HERE as a pure const-array
  `from_array(...)`, no shuffle dep needed

`Default` impl + 26 unit tests included in-file.

**Same SimdOrd import finding** as agents #5, #6 — needed for
simd_min/simd_max on integer types.


## 2026-05-13T22:25 — agent #4 u-word-wrap (sonnet) [backfilled by main]

**File:** `src/simd_nightly/u_word_types.rs` (~520 lines)
**Status:** COMPILES

5 wrappers: `U16x32`, `U32x16`, `U32x8`, `U64x8`, `U64x4`. Per-type
surface: splat, from_slice, from_array, to_array, copy_to_slice,
reduce_sum/min/max, simd_min/max, cmpeq_mask, cmpgt_mask, Default.
U16x32 also has saturating_add/sub.

**Mask widths:** cmpeq/cmpgt return u32 (32-lane), u16 (16-lane), u8
(8-lane and 4-lane). Cast from u64 since `to_bitmask()` always returns
u64 (same finding as agents #5/#6/#7).

**Same SimdOrd import finding** + `SimdPartialOrd` for cmpgt_mask.

## 2026-05-13T22:30 — agent #10 ops-macros (sonnet) [backfilled by main]

**File:** `src/simd_nightly/ops.rs` (265 lines)
**Status:** COMPILES

4 macros:
- `impl_fp_ops!($T)` — Add/Sub/Mul/Div/Neg + 5 *Assign variants
- `impl_int_ops!($T)` — Add/Sub/BitAnd/BitOr/BitXor + 5 *Assign
- `impl_int_neg!($T)` — Neg only, applied to signed ints
- `impl_default!($T)` — `Self(Default::default())`

Invocations cover: F32x16, F32x8, F64x8, F64x4, U8x32, U8x64, U16x32,
U32x16, U32x8, U64x8, U64x4, I8x32, I8x64, I16x16, I16x32, I32x16,
I64x8 — every concrete type defined by agents #1-#6.

Floats use fp_ops; unsigned ints use int_ops only (no Neg); signed
ints get int_ops + int_neg. Default impls live either in this file OR
in the type-defining files — checked to avoid duplicates.

## 2026-05-13T22:35 — agent #11 exotic-fallbacks (sonnet) [backfilled by main]

**File:** `src/simd_nightly/exotic_methods.rs` (329 lines)
**Status:** COMPILES

Extension `impl U8x64` / `impl U8x32` blocks (Rust allows multiple
impl-per-type within a crate) providing 5 methods `core::simd` lacks:

- `permute_bytes(idx: Self) -> Self` — cross-lane scalar fallback,
  idx masked `& 63` for U8x64 / `& 31` for U8x32
- `shuffle_bytes(idx: Self) -> Self` — within-128-bit-lane; high bit
  (0x80) zeroes the lane, low 4 bits index within 16-byte lane
- `mask_blend(mask: u64|u32, a, b) -> Self` — bitmask-driven select
- `unpack_lo_epi8(self, other)` / `unpack_hi_epi8(self, other)` —
  per-128-bit-lane byte interleave

`nibble_popcount_lut()` NOT duplicated here — agent #3 placed it in
u8_types.rs as a pure const-array `from_array(...)`.

24 unit tests across all 10 new methods (5 per type).
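
The `to_bitmask()` widening that agents #3 through #7 each rediscovered, as one illustrative sketch (not project code; shape per their reports, and the function name is hypothetical):

```rust
#![feature(portable_simd)]
use core::simd::cmp::SimdPartialEq;
use core::simd::u8x32;

// 32-lane compare → AVX2-shaped u32 mask: `to_bitmask()` yields u64
// for every lane count, so narrow it to the lane width with a cast.
fn cmpeq_mask_u32(a: u8x32, b: u8x32) -> u32 {
    a.simd_eq(b).to_bitmask() as u32
}
```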
+ +## 2026-05-13T22:40 — agent #12 parity-tests (sonnet) [backfilled by main] + +**File:** `src/simd_nightly/tests.rs` (76 new tests) +**Status:** ALL 76 PASS (`cargo +nightly test --features nightly-simd +-p ndarray --lib simd_nightly`: 153 total = 76 new + 77 pre-existing +from agent in-file tests; all pass) + +Coverage: +1. Constructor roundtrip — F32x16, F32x8, F64x8, F64x4 +2. Reduction parity (vs scalar fold) — all floats + U64x8/4, U32x16/8, U16x32 +3. Comparison mask parity — F32x16, F32x8, F64x8, F64x4, U8x32, U8x64 +4. Saturating arithmetic — U8x64, U8x32, U16x32 (max/min clamps) +5. FMA bit-exact — F32x16, F32x8, F64x8, F64x4 (`0.5.mul_add(2.0, 1.0) == 2.0`) +6. BF16/F16 roundtrip — within truncation error; bit pattern identity +7. Mask select — F32Mask16/8, F64Mask8/4; bitmask roundtrip +8. Exotic methods — permute_bytes reverse identity for U8x64/U8x32; + nibble_popcount_lut vs `u32::count_ones` for all 16 nibbles; + shuffle_bytes popcount parity +9. Additional — sqrt/abs/floor/round; to_bits/from_bits roundtrip; + arithmetic ops (BitAnd/Or/Xor); simd_clamp parity + +**Gap noted:** I8x32, I8x64, I16x16, I16x32, I32x16, I64x8 NOT covered +in this batch because agents #5 and #6 hadn't landed when agent #12 ran. +Follow-up: add ~20 signed-int tests to bring total to ~96. + diff --git a/Cargo.toml b/Cargo.toml index b1e30f85..633004a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -147,6 +147,20 @@ serde = ["dep:serde"] std = ["num-traits/std", "matrixmultiply/std"] rayon = ["dep:rayon", "std"] +# Portable-SIMD backend (NIGHTLY ONLY). Routes `crate::simd::*` types +# through `core::simd::*` instead of the architecture-specific intrinsics +# in `simd_avx512.rs` / `simd_avx2.rs` / `simd_neon.rs`. The point is +# miri compatibility: miri can execute `core::simd` semantics but treats +# `_mm*_*` intrinsics as opaque. With this feature on, miri-run tests +# exercise the actual SIMD code paths in consumer code (`hpc/byte_scan`, +# `hpc/framebuffer`, etc.) and catch UB that the intrinsics backend hides. +# +# Requires `cargo +nightly` because `src/simd_nightly.rs` is gated on +# `#![feature(portable_simd)]` (Rust unstable issue #86656). The default +# build (stable 1.95) does NOT touch this; the existing intrinsics +# cfg-dispatch in `simd.rs` remains the production path. +nightly-simd = ["std"] + # HPC extras: blake3 hashing, p64 palette/NARS bridge, fractal manifold. # These pull in a non-trivial dependency tree; downstream crates such as # burn-ndarray that only need the core array layer can disable this with diff --git a/src/lib.rs b/src/lib.rs index db94e9c4..7c77a07a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,12 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. #![crate_name = "ndarray"] +// Crate-level nightly feature gate for the optional `nightly-simd` backend +// (`src/simd_nightly/`). When the `nightly-simd` cargo feature is OFF +// (default), this attribute is absent and stable rustc compiles the crate +// normally. When ON, the crate requires nightly rustc to access +// `core::simd::*` types. +#![cfg_attr(feature = "nightly-simd", feature(portable_simd))] #![doc(html_root_url = "https://docs.rs/ndarray/0.15/")] #![doc(html_logo_url = "https://rust-ndarray.github.io/images/rust-ndarray_logo.svg")] #![allow( @@ -240,6 +246,14 @@ pub(crate) mod simd_avx512; #[allow(clippy::all, missing_docs, dead_code, unused_variables, unused_imports)] pub mod simd_avx2; +// Portable-SIMD backend — nightly-only. 
Wraps `core::simd::*` so miri can
+// execute the polyfill paths (intrinsic-based backends are opaque to
+// miri). Gated behind the `nightly-simd` feature; the file itself requires
+// `#![feature(portable_simd)]` so it only compiles on nightly rustc.
+#[cfg(feature = "nightly-simd")]
+#[allow(clippy::all, missing_docs)]
+pub mod simd_nightly;
+
 #[cfg(feature = "std")]
 #[allow(clippy::all, missing_docs, dead_code, unused_variables, unused_imports)]
 // AMX is an x86_64-only ISA (Intel Sapphire Rapids+); the module uses
diff --git a/src/simd.rs b/src/simd.rs
index 5f37eb4c..ed3e0dea 100644
--- a/src/simd.rs
+++ b/src/simd.rs
@@ -203,6 +203,15 @@ pub const PREFERRED_I16_LANES: usize = 16;
 // at compile time → all types use native __m512/__m512d/__m512i.
 // The 256-bit types (F32x8, F64x4) also live in simd_avx512 (__m256).

+// Note on the `nightly-simd` feature: it adds the `crate::simd_nightly`
+// module (a portable-simd backend wrapping `core::simd`) but does NOT
+// replace the intrinsics dispatch below. Full type-parity coverage
+// would require the nightly module to define ~30 types; the current
+// draft covers 5 (F32x16, F64x8, U8x64, U32x16, F32Mask16). Consumers
+// who want miri-runnable SIMD code import from `simd_nightly`
+// explicitly (e.g. `use ndarray::simd_nightly::F32x16`). The main
+// polyfill via `crate::simd::F32x16` continues to use intrinsics.
+
 #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
 pub use crate::simd_avx512::{
 f32x16,
diff --git a/src/simd_nightly/_original_draft.rs b/src/simd_nightly/_original_draft.rs
new file mode 100644
index 00000000..9b5af799
--- /dev/null
+++ b/src/simd_nightly/_original_draft.rs
@@ -0,0 +1,674 @@
+//! Portable-SIMD polyfill backend — `core::simd::*` adapter (NIGHTLY ONLY).
+//!
+//! **Draft.** Not wired into `simd.rs` yet. Gated behind the
+//! `nightly-simd` feature (which itself requires nightly Rust because
+//! `core::simd` is unstable per rust-lang/rust#86656).
+//!
+//! # Why this exists
+//!
+//! The default polyfill (`simd_avx512.rs` / `simd_avx2.rs` / `simd_neon.rs`)
+//! uses architecture intrinsics (`_mm512_*`, `_mm256_*`, `vaddq_*`). Miri
+//! cannot execute those — it treats them as opaque foreign-function calls
+//! and skips them. That means miri-runs of consumer code (`hpc/byte_scan`,
+//! `compose_neo4j`, the bevy plugin's tick path) never exercise the SIMD
+//! branches, so UB inside them (out-of-bounds loads, misaligned reads,
+//! aliasing violations) goes undetected.
+//!
+//! `core::simd` is portable — miri executes the semantics directly. With
+//! this backend swapped in via `--features nightly-simd`, the same consumer
+//! code becomes miri-checkable on every SIMD path. Bugs that the
+//! intrinsics-based backend would let through silently still trip miri here.
+//!
+//! # Tradeoffs
+//!
+//! - **Nightly only.** `core::simd` is gated on `#![feature(portable_simd)]`.
+//! The default build (stable 1.95) continues to use the intrinsics
+//! backends per the existing `simd.rs` cfg dispatch.
+//! - **Codegen quality varies.** `core::simd` lowers to platform SIMD via
+//! LLVM's portable-simd lowering. On AVX-512 with `-C target-cpu=v4`,
+//! the codegen is competitive (LLVM picks `_mm512_*`). On weaker targets
+//! it may produce scalar code where the intrinsics path picks AVX2.
+//! - **Mask shapes differ.** AVX-512 uses 64-bit bitmask registers
+//! (`__mmask64`); `core::simd::Mask` is a vector mask. The
+//! `cmpeq_mask` adapter bridges via `to_bitmask()` so consumer code
+//! sees the same `u64` shape.
+//!
+//! 
# Status +//! +//! Coverage in this draft: +//! - `F32x16`, `F64x8` — full method set (splat, FMA, reductions, +//! comparisons → bitmask, simd_clamp, sqrt/floor/round/abs). +//! - `U8x64` — load/store + reductions + `cmpeq_mask → u64` + +//! `saturating_add/sub` + `pairwise_avg`. +//! - Operator impls (`Add`, `Sub`, `Mul`, `Div`, `BitAnd`, `BitOr`, +//! `BitXor`) via macro. +//! +//! NOT covered (will deny-compile if a consumer reaches for them under +//! `--features nightly-simd`): +//! - `permute_bytes`, `shuffle_bytes` — `core::simd::Simd::swizzle` is +//! `const N` so it can't take a runtime `idx` vector. Needs a scalar +//! fallback like the AVX-512F-without-VBMI path in `simd_avx512.rs`. +//! - `BF16x16`, `BF16x8`, `F16x16` — `core::simd` has no half-precision +//! types; needs a manual `bf16_to_f32_batch` adapter. +//! - AMX tile types — out of scope for portable simd entirely. +//! - 32-/64-/128-bit integer SIMD types (`I32x16`, `U32x16`, `I64x8`, +//! `U64x8`, `U16x32`) — straightforward to add via the same macro +//! pattern; left for a follow-up commit. +//! +//! # Wiring into simd.rs (sketch — not done in this draft) +//! +//! ```rust,ignore +//! // In src/simd.rs, add a high-priority cfg branch: +//! #[cfg(feature = "nightly-simd")] +//! pub use crate::simd_nightly::{F32x16, F64x8, U8x64}; +//! +//! #[cfg(all(not(feature = "nightly-simd"), target_arch = "x86_64", +//! target_feature = "avx512f"))] +//! pub use crate::simd_avx512::{F32x16, F64x8, U8x64}; +//! // ... existing branches +//! ``` +//! +//! # Cargo.toml (sketch — not done in this draft) +//! +//! ```toml +//! [features] +//! nightly-simd = [] # requires nightly rustc; enables core::simd backend +//! ``` + +#![cfg(feature = "nightly-simd")] +#![feature(portable_simd)] + +use core::simd::cmp::{SimdPartialEq, SimdPartialOrd}; +use core::simd::num::{SimdFloat, SimdInt, SimdUint}; +use core::simd::{f32x16 as core_f32x16, f64x8 as core_f64x8, u8x64 as core_u8x64}; +use core::simd::{Mask, Simd, ToBitMask}; + +// ════════════════════════════════════════════════════════════════════ +// F32x16 — 16-lane single-precision float +// ════════════════════════════════════════════════════════════════════ + +/// 16-lane `f32` SIMD vector backed by `core::simd::f32x16`. +/// +/// API mirrors `simd_avx512::F32x16` so consumer code is identical. +/// Miri can execute every method below — unlike the intrinsics +/// backend, where SIMD paths are opaque to miri. 
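+///
+/// Illustrative consumer shape (hypothetical snippet; the draft is not
+/// wired into `simd.rs`, so the surrounding imports are assumed):
+/// ```rust,ignore
+/// let p = F32x16::from_array([1.0; 16]);
+/// let v = F32x16::splat(0.5);
+/// let p2 = v.mul_add(F32x16::splat(2.0), p); // 0.5 * 2.0 + 1.0 = 2.0 per lane
+/// assert_eq!(p2.to_array(), [2.0; 16]);
+/// ```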
+#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct F32x16(pub core_f32x16); + +impl F32x16 { + pub const LANES: usize = 16; + + #[inline(always)] + pub fn splat(v: f32) -> Self { + Self(core_f32x16::splat(v)) + } + + #[inline(always)] + pub fn from_array(arr: [f32; 16]) -> Self { + Self(core_f32x16::from_array(arr)) + } + + #[inline(always)] + pub fn from_slice(s: &[f32]) -> Self { + assert!(s.len() >= 16, "F32x16::from_slice needs ≥16 elements"); + Self(core_f32x16::from_slice(s)) + } + + #[inline(always)] + pub fn to_array(self) -> [f32; 16] { + self.0.to_array() + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [f32]) { + assert!(s.len() >= 16, "F32x16::copy_to_slice needs ≥16 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + #[inline(always)] + pub fn reduce_sum(self) -> f32 { + self.0.reduce_sum() + } + + #[inline(always)] + pub fn reduce_min(self) -> f32 { + self.0.reduce_min() + } + + #[inline(always)] + pub fn reduce_max(self) -> f32 { + self.0.reduce_max() + } + + // ── Lane-wise min/max/clamp ─────────────────────────────────── + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + #[inline(always)] + pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { + Self(self.0.simd_clamp(lo.0, hi.0)) + } + + // ── FMA + math ──────────────────────────────────────────────── + + /// Fused multiply-add: `self * b + c`. Maps to `_mm512_fmadd_ps` on + /// AVX-512 builds via LLVM's portable-simd lowering, scalar + /// `f32::mul_add` per lane otherwise. + #[inline(always)] + pub fn mul_add(self, b: Self, c: Self) -> Self { + Self(self.0.mul_add(b.0, c.0)) + } + + #[inline(always)] + pub fn sqrt(self) -> Self { + Self(self.0.sqrt()) + } + + #[inline(always)] + pub fn floor(self) -> Self { + Self(self.0.floor()) + } + + #[inline(always)] + pub fn round(self) -> Self { + Self(self.0.round()) + } + + #[inline(always)] + pub fn abs(self) -> Self { + Self(self.0.abs()) + } + + // ── Bit reinterpretation ────────────────────────────────────── + + #[inline(always)] + pub fn to_bits(self) -> U32x16 { + U32x16(self.0.to_bits()) + } + + #[inline(always)] + pub fn from_bits(bits: U32x16) -> Self { + Self(core_f32x16::from_bits(bits.0)) + } + + // ── Comparisons → mask ──────────────────────────────────────── + // + // The intrinsics backend exposes `simd_eq` / `simd_lt` / etc. + // returning `F32Mask16(__mmask16)` (a 16-bit bitmask). We mirror + // that by wrapping `core::simd::Mask` and providing + // `to_bitmask()` for consumers that want the u16 shape. + + #[inline(always)] + pub fn simd_eq(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_eq(other.0)) + } + + #[inline(always)] + pub fn simd_ne(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_ne(other.0)) + } + + #[inline(always)] + pub fn simd_lt(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_lt(other.0)) + } + + #[inline(always)] + pub fn simd_le(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_le(other.0)) + } + + #[inline(always)] + pub fn simd_gt(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_gt(other.0)) + } + + #[inline(always)] + pub fn simd_ge(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_ge(other.0)) + } +} + +/// 16-lane mask for `F32x16` comparisons. 
+#[derive(Copy, Clone, Debug)]
+pub struct F32Mask16(pub Mask<i32, 16>);
+
+impl F32Mask16 {
+ /// Convert to a 16-bit packed bitmask (matches the AVX-512 `__mmask16`
+ /// shape). Bit i set iff lane i of the mask is true.
+ #[inline(always)]
+ pub fn to_bitmask(self) -> u16 {
+ self.0.to_bitmask() as u16
+ }
+
+ /// Per-lane select: returns `true_val[i]` where mask[i] is set,
+ /// else `false_val[i]`.
+ #[inline(always)]
+ pub fn select(self, true_val: F32x16, false_val: F32x16) -> F32x16 {
+ F32x16(self.0.select(true_val.0, false_val.0))
+ }
+}
+
+// ════════════════════════════════════════════════════════════════════
+// F64x8 — 8-lane double-precision float
+// ════════════════════════════════════════════════════════════════════
+
+#[derive(Copy, Clone, Debug)]
+#[repr(transparent)]
+pub struct F64x8(pub core_f64x8);
+
+impl F64x8 {
+ pub const LANES: usize = 8;
+
+ #[inline(always)]
+ pub fn splat(v: f64) -> Self {
+ Self(core_f64x8::splat(v))
+ }
+
+ #[inline(always)]
+ pub fn from_array(arr: [f64; 8]) -> Self {
+ Self(core_f64x8::from_array(arr))
+ }
+
+ #[inline(always)]
+ pub fn from_slice(s: &[f64]) -> Self {
+ assert!(s.len() >= 8, "F64x8::from_slice needs ≥8 elements");
+ Self(core_f64x8::from_slice(s))
+ }
+
+ #[inline(always)]
+ pub fn to_array(self) -> [f64; 8] {
+ self.0.to_array()
+ }
+
+ #[inline(always)]
+ pub fn copy_to_slice(self, s: &mut [f64]) {
+ assert!(s.len() >= 8, "F64x8::copy_to_slice needs ≥8 elements");
+ self.0.copy_to_slice(s);
+ }
+
+ #[inline(always)]
+ pub fn reduce_sum(self) -> f64 {
+ self.0.reduce_sum()
+ }
+ #[inline(always)]
+ pub fn reduce_min(self) -> f64 {
+ self.0.reduce_min()
+ }
+ #[inline(always)]
+ pub fn reduce_max(self) -> f64 {
+ self.0.reduce_max()
+ }
+
+ #[inline(always)]
+ pub fn simd_min(self, other: Self) -> Self {
+ Self(self.0.simd_min(other.0))
+ }
+ #[inline(always)]
+ pub fn simd_max(self, other: Self) -> Self {
+ Self(self.0.simd_max(other.0))
+ }
+ #[inline(always)]
+ pub fn simd_clamp(self, lo: Self, hi: Self) -> Self {
+ Self(self.0.simd_clamp(lo.0, hi.0))
+ }
+
+ #[inline(always)]
+ pub fn mul_add(self, b: Self, c: Self) -> Self {
+ Self(self.0.mul_add(b.0, c.0))
+ }
+ #[inline(always)]
+ pub fn sqrt(self) -> Self {
+ Self(self.0.sqrt())
+ }
+ #[inline(always)]
+ pub fn floor(self) -> Self {
+ Self(self.0.floor())
+ }
+ #[inline(always)]
+ pub fn round(self) -> Self {
+ Self(self.0.round())
+ }
+ #[inline(always)]
+ pub fn abs(self) -> Self {
+ Self(self.0.abs())
+ }
+}
+
+// ════════════════════════════════════════════════════════════════════
+// U8x64 — 64-lane unsigned-byte (the rasterizer / palette / NBT width)
+// ════════════════════════════════════════════════════════════════════
+
+#[derive(Copy, Clone, Debug)]
+#[repr(transparent)]
+pub struct U8x64(pub core_u8x64);
+
+impl U8x64 {
+ pub const LANES: usize = 64;
+
+ #[inline(always)]
+ pub fn splat(v: u8) -> Self {
+ Self(core_u8x64::splat(v))
+ }
+ #[inline(always)]
+ pub fn from_array(arr: [u8; 64]) -> Self {
+ Self(core_u8x64::from_array(arr))
+ }
+ #[inline(always)]
+ pub fn from_slice(s: &[u8]) -> Self {
+ assert!(s.len() >= 64, "U8x64::from_slice needs ≥64 bytes");
+ Self(core_u8x64::from_slice(s))
+ }
+ #[inline(always)]
+ pub fn to_array(self) -> [u8; 64] {
+ self.0.to_array()
+ }
+ #[inline(always)]
+ pub fn copy_to_slice(self, s: &mut [u8]) {
+ assert!(s.len() >= 64, "U8x64::copy_to_slice needs ≥64 bytes");
+ self.0.copy_to_slice(s);
+ }
+
+ // ── Reductions ────────────────────────────────────────────────
+
+ #[inline(always)]
+ pub fn reduce_sum(self) -> u8 {
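+ // NB: integer `reduce_sum` wraps on overflow; 64 u8 lanes can
+ // easily exceed 255, so callers wanting an exact byte-count sum
+ // need a widening reduction instead.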
+ self.0.reduce_sum()
+ }
+ #[inline(always)]
+ pub fn reduce_min(self) -> u8 {
+ self.0.reduce_min()
+ }
+ #[inline(always)]
+ pub fn reduce_max(self) -> u8 {
+ self.0.reduce_max()
+ }
+
+ // ── Lane-wise min / max ───────────────────────────────────────
+
+ #[inline(always)]
+ pub fn simd_min(self, other: Self) -> Self {
+ Self(self.0.simd_min(other.0))
+ }
+ #[inline(always)]
+ pub fn simd_max(self, other: Self) -> Self {
+ Self(self.0.simd_max(other.0))
+ }
+
+ // ── Saturating arithmetic ─────────────────────────────────────
+
+ #[inline(always)]
+ pub fn saturating_add(self, other: Self) -> Self {
+ Self(self.0.saturating_add(other.0))
+ }
+ #[inline(always)]
+ pub fn saturating_sub(self, other: Self) -> Self {
+ Self(self.0.saturating_sub(other.0))
+ }
+
+ /// Per-lane unsigned rounded average: `(a + b + 1) >> 1`.
+ /// `core::simd` has no native `avg_epu8` equivalent, so we compute
+ /// it from the standard identity. LLVM may still lower to
+ /// `_mm512_avg_epu8` on AVX-512 builds; verify with --emit asm.
+ #[inline(always)]
+ pub fn pairwise_avg(self, other: Self) -> Self {
+ // Promote to u16 to avoid overflow on `a + b + 1`.
+ let a16 = self.0.cast::<u16>();
+ let b16 = other.0.cast::<u16>();
+ let avg = (a16 + b16 + core::simd::Simd::splat(1)) >> core::simd::Simd::splat(1);
+ Self(avg.cast::<u8>())
+ }
+
+ // ── Comparison → 64-bit bitmask ───────────────────────────────
+
+ /// Per-lane equality. Returns a 64-bit mask (bit i = 1 iff
+ /// `self[i] == other[i]`). Matches the AVX-512 `__mmask64` shape
+ /// of `simd_avx512::U8x64::cmpeq_mask`.
+ #[inline(always)]
+ pub fn cmpeq_mask(self, other: Self) -> u64 {
+ self.0.simd_eq(other.0).to_bitmask()
+ }
+
+ /// Per-lane unsigned greater-than. Returns a 64-bit mask.
+ #[inline(always)]
+ pub fn cmpgt_mask(self, other: Self) -> u64 {
+ self.0.simd_gt(other.0).to_bitmask()
+ }
+
+ /// Extract MSB of each byte as a 64-bit mask. Equivalent to
+ /// `_mm512_movepi8_mask`.
+ #[inline(always)]
+ pub fn movemask(self) -> u64 {
+ // MSB-set ⇔ value ≥ 128 ⇔ value > 127 in unsigned cmp.
+ self.0.simd_gt(core_u8x64::splat(0x7F)).to_bitmask()
+ }
+}
+
+// ════════════════════════════════════════════════════════════════════
+// U32x16 — companion type for F32x16::to_bits / from_bits
+// ════════════════════════════════════════════════════════════════════
+
+#[derive(Copy, Clone, Debug)]
+#[repr(transparent)]
+pub struct U32x16(pub Simd<u32, 16>);
+
+impl U32x16 {
+ pub const LANES: usize = 16;
+
+ #[inline(always)]
+ pub fn splat(v: u32) -> Self {
+ Self(Simd::splat(v))
+ }
+ #[inline(always)]
+ pub fn from_array(arr: [u32; 16]) -> Self {
+ Self(Simd::from_array(arr))
+ }
+ #[inline(always)]
+ pub fn to_array(self) -> [u32; 16] {
+ self.0.to_array()
+ }
+}
+
+// ════════════════════════════════════════════════════════════════════
+// Operator impls — `Add`, `Sub`, `Mul`, `Div` for floats; bitwise for
+// ints. Macro keeps the boilerplate from spreading.
+// ════════════════════════════════════════════════════════════════════
+
+macro_rules! 
impl_fp_ops { + ($name:ident) => { + impl core::ops::Add for $name { + type Output = Self; + #[inline(always)] + fn add(self, rhs: Self) -> Self { + Self(self.0 + rhs.0) + } + } + impl core::ops::Sub for $name { + type Output = Self; + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + Self(self.0 - rhs.0) + } + } + impl core::ops::Mul for $name { + type Output = Self; + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + Self(self.0 * rhs.0) + } + } + impl core::ops::Div for $name { + type Output = Self; + #[inline(always)] + fn div(self, rhs: Self) -> Self { + Self(self.0 / rhs.0) + } + } + impl core::ops::Neg for $name { + type Output = Self; + #[inline(always)] + fn neg(self) -> Self { + Self(-self.0) + } + } + }; +} + +macro_rules! impl_int_ops { + ($name:ident) => { + impl core::ops::Add for $name { + type Output = Self; + #[inline(always)] + fn add(self, rhs: Self) -> Self { + Self(self.0 + rhs.0) + } + } + impl core::ops::Sub for $name { + type Output = Self; + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + Self(self.0 - rhs.0) + } + } + impl core::ops::BitAnd for $name { + type Output = Self; + #[inline(always)] + fn bitand(self, rhs: Self) -> Self { + Self(self.0 & rhs.0) + } + } + impl core::ops::BitOr for $name { + type Output = Self; + #[inline(always)] + fn bitor(self, rhs: Self) -> Self { + Self(self.0 | rhs.0) + } + } + impl core::ops::BitXor for $name { + type Output = Self; + #[inline(always)] + fn bitxor(self, rhs: Self) -> Self { + Self(self.0 ^ rhs.0) + } + } + }; +} + +impl_fp_ops!(F32x16); +impl_fp_ops!(F64x8); +impl_int_ops!(U8x64); +impl_int_ops!(U32x16); + +// ════════════════════════════════════════════════════════════════════ +// Default impls — `Self::splat(0)` shape. +// ════════════════════════════════════════════════════════════════════ + +impl Default for F32x16 { + #[inline(always)] + fn default() -> Self { + Self::splat(0.0) + } +} +impl Default for F64x8 { + #[inline(always)] + fn default() -> Self { + Self::splat(0.0) + } +} +impl Default for U8x64 { + #[inline(always)] + fn default() -> Self { + Self::splat(0) + } +} +impl Default for U32x16 { + #[inline(always)] + fn default() -> Self { + Self::splat(0) + } +} + +// ════════════════════════════════════════════════════════════════════ +// Tests — small parity sanity for the methods consumers rely on. +// Real test surface (miri-runnable) lives in `tests/` once wired. +// ════════════════════════════════════════════════════════════════════ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn f32x16_fma_lane_exact() { + let a = F32x16::splat(0.5); + let b = F32x16::splat(2.0); + let c = F32x16::splat(1.0); + // 0.5 * 2.0 + 1.0 = 2.0 per lane + assert!(a.mul_add(b, c).to_array().iter().all(|&v| (v - 2.0).abs() < 1e-6)); + } + + #[test] + fn f32x16_mask_bitmask_roundtrip() { + let v = F32x16::from_array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]); + let threshold = F32x16::splat(8.5); + let m = v.simd_lt(threshold); + // First 8 lanes < 8.5, rest aren't. 
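+ // `to_bitmask` packs lane i into bit i (LSB first), so the low
+ // 8 bits are set: 0x00FF.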
+ assert_eq!(m.to_bitmask(), 0b0000_0000_1111_1111); + } + + #[test] + fn u8x64_cmpeq_mask_matches_scalar() { + let a: [u8; 64] = core::array::from_fn(|i| (i & 7) as u8); + let b: [u8; 64] = core::array::from_fn(|i| (i & 6) as u8); + let m = U8x64::from_array(a).cmpeq_mask(U8x64::from_array(b)); + for i in 0..64 { + assert_eq!(((m >> i) & 1) == 1, a[i] == b[i], "lane {i}"); + } + } + + #[test] + fn u8x64_saturating_add_clamps_to_255() { + let a = U8x64::splat(200); + let b = U8x64::splat(100); + assert_eq!(a.saturating_add(b).to_array(), [255u8; 64]); + } + + #[test] + fn u8x64_pairwise_avg_rounds_up() { + let a = U8x64::splat(7); + let b = U8x64::splat(8); + // (7 + 8 + 1) / 2 = 8 + assert_eq!(a.pairwise_avg(b).to_array(), [8u8; 64]); + } + + #[test] + fn integrate_simd_shape_via_polyfill() { + // The hot loop in hpc::renderer::integrate_simd looks like: + // for chunk in positions.chunks_exact_mut(16) { + // let p = F32x16::from_slice(chunk); + // let v = F32x16::from_slice(velocity_chunk); + // let dt_v = F32x16::splat(dt); + // let p_new = v.mul_add(dt_v, p); + // p_new.copy_to_slice(chunk); + // } + // Verify the same shape works through the nightly backend. + let mut p = [0.0f32; 16]; + let v = F32x16::splat(1.0); + let dt = F32x16::splat(1.0 / 60.0); + let p0 = F32x16::from_slice(&p); + v.mul_add(dt, p0).copy_to_slice(&mut p); + for x in p { + assert!((x - 1.0 / 60.0).abs() < 1e-6); + } + } +} diff --git a/src/simd_nightly/bf16_types.rs b/src/simd_nightly/bf16_types.rs new file mode 100644 index 00000000..e6c408ac --- /dev/null +++ b/src/simd_nightly/bf16_types.rs @@ -0,0 +1,269 @@ +//! BF16x16 / BF16x8 portable-simd wrappers (scalar emulation — no core::simd +//! half-precision types). Round-3-portable-simd agent #8. +#![cfg(feature = "nightly-simd")] + +/// 16-lane BF16 vector backed by `[u16; 16]` (scalar emulation). +/// +/// `core::simd` has no native half-precision type, so bit patterns are stored +/// as `u16` and operations upcast through `f32` where needed. +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct BF16x16(pub [u16; 16]); + +/// 8-lane BF16 vector backed by `[u16; 8]` (scalar emulation). +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct BF16x8(pub [u16; 8]); + +// --------------------------------------------------------------------------- +// Shared conversion helpers +// --------------------------------------------------------------------------- + +/// Truncate an `f32` to BF16 bit pattern (drop low 16 mantissa bits). +#[inline(always)] +fn f32_to_bf16_bits(v: f32) -> u16 { + (v.to_bits() >> 16) as u16 +} + +/// Reconstruct an `f32` from a BF16 bit pattern (lossless). +#[inline(always)] +fn bf16_bits_to_f32(bits: u16) -> f32 { + f32::from_bits((bits as u32) << 16) +} + +// --------------------------------------------------------------------------- +// BF16x16 +// --------------------------------------------------------------------------- + +impl BF16x16 { + /// Number of BF16 lanes. + pub const LANES: usize = 16; + + /// Broadcast `v` (converted f32 → BF16 via truncation) across all 16 lanes. + #[inline] + pub fn splat(v: f32) -> Self { + Self([f32_to_bf16_bits(v); 16]) + } + + /// Load 16 BF16 bit-patterns from a `u16` slice (must have `len >= 16`). + #[inline] + pub fn from_slice(s: &[u16]) -> Self { + assert!(s.len() >= 16, "BF16x16::from_slice: need >= 16 elements, got {}", s.len()); + let mut arr = [0u16; 16]; + arr.copy_from_slice(&s[..16]); + Self(arr) + } + + /// Construct from an array of 16 BF16 bit-patterns. 
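+ /// The `u16` bit patterns are stored verbatim; no f32 conversion
+ /// is performed.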
+ #[inline] + pub fn from_array(arr: [u16; 16]) -> Self { + Self(arr) + } + + /// Return the underlying array of 16 BF16 bit-patterns. + #[inline] + pub fn to_array(self) -> [u16; 16] { + self.0 + } + + /// Write the 16 BF16 bit-patterns into `dst` (must have `len >= 16`). + #[inline] + pub fn copy_to_slice(self, dst: &mut [u16]) { + assert!(dst.len() >= 16, "BF16x16::copy_to_slice: need >= 16 elements, got {}", dst.len()); + dst[..16].copy_from_slice(&self.0); + } + + /// Convert all 16 BF16 lanes to `f32` (lossless: zero-extend bits 15:0 → bits 31:16). + #[inline] + pub fn to_f32_lossy(self) -> [f32; 16] { + let mut out = [0.0f32; 16]; + for i in 0..16 { + out[i] = bf16_bits_to_f32(self.0[i]); + } + out + } + + /// Build from 16 `f32` values, truncating (not rounding) the low 16 mantissa bits. + #[inline] + pub fn from_f32_truncate(arr: [f32; 16]) -> Self { + let mut out = [0u16; 16]; + for i in 0..16 { + out[i] = f32_to_bf16_bits(arr[i]); + } + Self(out) + } +} + +// --------------------------------------------------------------------------- +// BF16x8 +// --------------------------------------------------------------------------- + +impl BF16x8 { + /// Number of BF16 lanes. + pub const LANES: usize = 8; + + /// Broadcast `v` (converted f32 → BF16 via truncation) across all 8 lanes. + #[inline] + pub fn splat(v: f32) -> Self { + Self([f32_to_bf16_bits(v); 8]) + } + + /// Load 8 BF16 bit-patterns from a `u16` slice (must have `len >= 8`). + #[inline] + pub fn from_slice(s: &[u16]) -> Self { + assert!(s.len() >= 8, "BF16x8::from_slice: need >= 8 elements, got {}", s.len()); + let mut arr = [0u16; 8]; + arr.copy_from_slice(&s[..8]); + Self(arr) + } + + /// Construct from an array of 8 BF16 bit-patterns. + #[inline] + pub fn from_array(arr: [u16; 8]) -> Self { + Self(arr) + } + + /// Return the underlying array of 8 BF16 bit-patterns. + #[inline] + pub fn to_array(self) -> [u16; 8] { + self.0 + } + + /// Write the 8 BF16 bit-patterns into `dst` (must have `len >= 8`). + #[inline] + pub fn copy_to_slice(self, dst: &mut [u16]) { + assert!(dst.len() >= 8, "BF16x8::copy_to_slice: need >= 8 elements, got {}", dst.len()); + dst[..8].copy_from_slice(&self.0); + } + + /// Convert all 8 BF16 lanes to `f32` (lossless). + #[inline] + pub fn to_f32_lossy(self) -> [f32; 8] { + let mut out = [0.0f32; 8]; + for i in 0..8 { + out[i] = bf16_bits_to_f32(self.0[i]); + } + out + } + + /// Build from 8 `f32` values, truncating the low 16 mantissa bits. 
+ #[inline] + pub fn from_f32_truncate(arr: [f32; 8]) -> Self { + let mut out = [0u16; 8]; + for i in 0..8 { + out[i] = f32_to_bf16_bits(arr[i]); + } + Self(out) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + // -- BF16x16 ------------------------------------------------------------ + + #[test] + fn bf16x16_splat_roundtrip() { + let v = BF16x16::splat(3.14f32); + let f32s = v.to_f32_lossy(); + // splat then to_f32_lossy must all be identical + assert!(f32s.windows(2).all(|w| w[0] == w[1])); + } + + #[test] + fn bf16x16_from_f32_truncate_to_f32_lossy() { + let input: [f32; 16] = core::array::from_fn(|i| i as f32 * 1.5); + let vec = BF16x16::from_f32_truncate(input); + let out = vec.to_f32_lossy(); + for i in 0..16 { + // truncation then expand must match bit-shift arithmetic + let expected = f32::from_bits((input[i].to_bits() >> 16) << 16); + assert_eq!(out[i], expected, "lane {i} mismatch"); + } + } + + #[test] + fn bf16x16_from_slice_copy_to_slice_roundtrip() { + let bits: [u16; 16] = core::array::from_fn(|i| (i as u16) * 0x100); + let v = BF16x16::from_slice(&bits); + let mut out = [0u16; 16]; + v.copy_to_slice(&mut out); + assert_eq!(bits, out); + } + + #[test] + fn bf16x16_from_array_to_array_roundtrip() { + let arr: [u16; 16] = core::array::from_fn(|i| i as u16 + 1); + assert_eq!(BF16x16::from_array(arr).to_array(), arr); + } + + #[test] + fn bf16x16_lanes_const() { + assert_eq!(BF16x16::LANES, 16); + } + + // -- BF16x8 ------------------------------------------------------------- + + #[test] + fn bf16x8_splat_roundtrip() { + let v = BF16x8::splat(2.71f32); + let f32s = v.to_f32_lossy(); + assert!(f32s.windows(2).all(|w| w[0] == w[1])); + } + + #[test] + fn bf16x8_from_f32_truncate_to_f32_lossy() { + let input: [f32; 8] = core::array::from_fn(|i| i as f32 * 0.5); + let vec = BF16x8::from_f32_truncate(input); + let out = vec.to_f32_lossy(); + for i in 0..8 { + let expected = f32::from_bits((input[i].to_bits() >> 16) << 16); + assert_eq!(out[i], expected, "lane {i} mismatch"); + } + } + + #[test] + fn bf16x8_from_slice_copy_to_slice_roundtrip() { + let bits: [u16; 8] = core::array::from_fn(|i| (i as u16) * 0x200); + let v = BF16x8::from_slice(&bits); + let mut out = [0u16; 8]; + v.copy_to_slice(&mut out); + assert_eq!(bits, out); + } + + #[test] + fn bf16x8_from_array_to_array_roundtrip() { + let arr: [u16; 8] = core::array::from_fn(|i| i as u16 + 10); + assert_eq!(BF16x8::from_array(arr).to_array(), arr); + } + + #[test] + fn bf16x8_lanes_const() { + assert_eq!(BF16x8::LANES, 8); + } + + // -- Known BF16 bit patterns ------------------------------------------- + + #[test] + fn bf16_one_point_zero() { + // 1.0f32 = 0x3F800000; BF16 = 0x3F80 + let v = BF16x16::splat(1.0f32); + assert_eq!(v.0[0], 0x3F80u16); + let f32s = v.to_f32_lossy(); + assert_eq!(f32s[0], 1.0f32); + } + + #[test] + fn bf16_negative_one() { + // -1.0f32 = 0xBF800000; BF16 = 0xBF80 + let v = BF16x8::splat(-1.0f32); + assert_eq!(v.0[0], 0xBF80u16); + let f32s = v.to_f32_lossy(); + assert_eq!(f32s[0], -1.0f32); + } +} diff --git a/src/simd_nightly/exotic_methods.rs b/src/simd_nightly/exotic_methods.rs new file mode 100644 index 00000000..baa6ede2 --- /dev/null +++ b/src/simd_nightly/exotic_methods.rs @@ -0,0 +1,564 @@ +//! Scalar fallbacks for U8x32 / U8x64 methods `core::simd` doesn't +//! 
natively support (cross-lane permute, within-lane shuffle, bitmask +//! blend, lane interleave, nibble-popcount LUT). +//! Round-3-portable-simd agent #11. +#![cfg(feature = "nightly-simd")] + +use super::u8_types::{U8x32, U8x64}; + +// ════════════════════════════════════════════════════════════════════ +// U8x64 extension methods +// ════════════════════════════════════════════════════════════════════ + +impl U8x64 { + /// Cross-lane byte permute: rearrange all 64 bytes by index vector. + /// + /// `idx[i]` selects which byte of `self` appears at position `i`. + /// The low 6 bits of each index are used (`idx[i] & 63`). + /// + /// `core::simd::Swizzle::swizzle` requires a `const N: usize` index + /// and cannot take a runtime `idx` vector. This scalar fallback via + /// `to_array()` / `from_array()` matches the AVX-512F-without-VBMI + /// path in `simd_avx512.rs::U8x64::permute_bytes` (lines ~695–710). + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// let v: U8x64 = U8x64::from_array(core::array::from_fn(|i| i as u8)); + /// // Identity permute + /// let idx: U8x64 = U8x64::from_array(core::array::from_fn(|i| i as u8)); + /// assert_eq!(v.permute_bytes(idx).to_array(), v.to_array()); + /// # } + /// ``` + #[inline] + pub fn permute_bytes(self, idx: Self) -> Self { + let src = self.to_array(); + let idx_arr = idx.to_array(); + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = src[(idx_arr[i] & 63) as usize]; + } + Self::from_array(out) + } + + /// Within-128-bit-lane byte shuffle. + /// + /// `self` is the LUT; `idx[i]` selects within the same 16-byte lane. + /// If the high bit (`0x80`) of `idx[i]` is set, the output lane is + /// zeroed. Only the low 4 bits select the source index within the lane. + /// Equivalent to `_mm512_shuffle_epi8` semantics. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// let lut = U8x64::nibble_popcount_lut(); + /// // All zeros → nibble popcount 0 + /// let idx = U8x64::splat(0); + /// assert_eq!(lut.shuffle_bytes(idx).to_array()[0], 0); + /// // High bit set → zeroed + /// let zero_idx = U8x64::splat(0x80); + /// assert_eq!(lut.shuffle_bytes(zero_idx).to_array()[0], 0); + /// # } + /// ``` + #[inline] + pub fn shuffle_bytes(self, idx: Self) -> Self { + let src = self.to_array(); + let idx_arr = idx.to_array(); + let mut out = [0u8; 64]; + for lane in 0..4 { + let base = lane * 16; + for i in 0..16 { + let ix = idx_arr[base + i]; + // High bit set → zero; low 4 bits → index within lane. + out[base + i] = if ix & 0x80 != 0 { + 0 + } else { + src[base + (ix & 0x0F) as usize] + }; + } + } + Self::from_array(out) + } + + /// Bitmask-driven blend: select `b[i]` where bit `i` of `mask` is set, + /// else `a[i]`. + /// + /// Mirrors `_mm512_mask_blend_epi8(mask, a, b)` semantics. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// let a = U8x64::splat(10); + /// let b = U8x64::splat(20); + /// // Bit 0 set → lane 0 comes from b. 
+ /// let r = U8x64::mask_blend(1u64, a, b); + /// assert_eq!(r.to_array()[0], 20); + /// assert_eq!(r.to_array()[1], 10); + /// # } + /// ``` + #[inline] + pub fn mask_blend(mask: u64, a: Self, b: Self) -> Self { + let a_arr = a.to_array(); + let b_arr = b.to_array(); + let mut out = [0u8; 64]; + for i in 0..64 { + out[i] = if mask & (1u64 << i) != 0 { b_arr[i] } else { a_arr[i] }; + } + Self::from_array(out) + } + + /// Interleave low bytes within each 128-bit lane. + /// + /// For each of the 4 × 16-byte lanes, takes the first 8 bytes of + /// `self` and `other` and interleaves them: + /// `[self[0], other[0], self[1], other[1], ..., self[7], other[7]]`. + /// + /// Mirrors `_mm512_unpacklo_epi8` semantics. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// let a = U8x64::splat(0xAA); + /// let b = U8x64::splat(0xBB); + /// let r = a.unpack_lo_epi8(b); + /// assert_eq!(r.to_array()[0], 0xAA); + /// assert_eq!(r.to_array()[1], 0xBB); + /// # } + /// ``` + #[inline] + pub fn unpack_lo_epi8(self, other: Self) -> Self { + let a = self.to_array(); + let b = other.to_array(); + let mut out = [0u8; 64]; + // 4 × 128-bit lanes, each 16 bytes wide. + for lane in 0..4 { + let base = lane * 16; + for i in 0..8 { + out[base + i * 2] = a[base + i]; + out[base + i * 2 + 1] = b[base + i]; + } + } + Self::from_array(out) + } + + /// Interleave high bytes within each 128-bit lane. + /// + /// For each of the 4 × 16-byte lanes, takes the last 8 bytes of + /// `self` and `other` and interleaves them: + /// `[self[8], other[8], self[9], other[9], ..., self[15], other[15]]`. + /// + /// Mirrors `_mm512_unpackhi_epi8` semantics. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// let a = U8x64::splat(0xAA); + /// let b = U8x64::splat(0xBB); + /// let r = a.unpack_hi_epi8(b); + /// assert_eq!(r.to_array()[0], 0xAA); + /// assert_eq!(r.to_array()[1], 0xBB); + /// # } + /// ``` + #[inline] + pub fn unpack_hi_epi8(self, other: Self) -> Self { + let a = self.to_array(); + let b = other.to_array(); + let mut out = [0u8; 64]; + for lane in 0..4 { + let base = lane * 16; + for i in 0..8 { + out[base + i * 2] = a[base + 8 + i]; + out[base + i * 2 + 1] = b[base + 8 + i]; + } + } + Self::from_array(out) + } +} + +// ════════════════════════════════════════════════════════════════════ +// U8x32 extension methods +// ════════════════════════════════════════════════════════════════════ + +impl U8x32 { + /// Cross-lane byte permute: rearrange all 32 bytes by index vector. + /// + /// `idx[i]` selects which byte of `self` appears at position `i`. + /// The low 5 bits of each index are used (`idx[i] & 31`). + /// + /// `core::simd::Swizzle::swizzle` requires a `const N: usize` index + /// and cannot take a runtime `idx` vector. Scalar fallback. 
+ /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// let v = U8x32::from_array(core::array::from_fn(|i| i as u8)); + /// let idx = U8x32::from_array(core::array::from_fn(|i| (31 - i) as u8)); + /// let r = v.permute_bytes(idx); + /// assert_eq!(r.to_array()[0], 31); + /// assert_eq!(r.to_array()[31], 0); + /// # } + /// ``` + #[inline] + pub fn permute_bytes(self, idx: Self) -> Self { + let src = self.to_array(); + let idx_arr = idx.to_array(); + let mut out = [0u8; 32]; + for i in 0..32 { + out[i] = src[(idx_arr[i] & 31) as usize]; + } + Self::from_array(out) + } + + /// Within-128-bit-lane byte shuffle. + /// + /// `self` is the LUT; `idx[i]` selects within the same 16-byte lane. + /// If the high bit (`0x80`) of `idx[i]` is set, the output lane is + /// zeroed. Only the low 4 bits select the source index within the lane. + /// Equivalent to `_mm256_shuffle_epi8` semantics. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// let lut = U8x32::nibble_popcount_lut(); + /// let idx = U8x32::splat(0); + /// assert_eq!(lut.shuffle_bytes(idx).to_array()[0], 0); + /// // High bit set → zeroed + /// let zero_idx = U8x32::splat(0x80); + /// assert_eq!(lut.shuffle_bytes(zero_idx).to_array()[0], 0); + /// # } + /// ``` + #[inline] + pub fn shuffle_bytes(self, idx: Self) -> Self { + let src = self.to_array(); + let idx_arr = idx.to_array(); + let mut out = [0u8; 32]; + for lane in 0..2 { + let base = lane * 16; + for i in 0..16 { + let ix = idx_arr[base + i]; + // High bit set → zero; low 4 bits → index within lane. + out[base + i] = if ix & 0x80 != 0 { + 0 + } else { + src[base + (ix & 0x0F) as usize] + }; + } + } + Self::from_array(out) + } + + /// Bitmask-driven blend: select `b[i]` where bit `i` of `mask` is set, + /// else `a[i]`. + /// + /// Uses a 32-bit bitmask (one bit per lane), matching the U8x32 width. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// let a = U8x32::splat(10); + /// let b = U8x32::splat(20); + /// // Bit 0 set → lane 0 comes from b. + /// let r = U8x32::mask_blend(1u32, a, b); + /// assert_eq!(r.to_array()[0], 20); + /// assert_eq!(r.to_array()[1], 10); + /// # } + /// ``` + #[inline] + pub fn mask_blend(mask: u32, a: Self, b: Self) -> Self { + let a_arr = a.to_array(); + let b_arr = b.to_array(); + let mut out = [0u8; 32]; + for i in 0..32 { + out[i] = if mask & (1u32 << i) != 0 { b_arr[i] } else { a_arr[i] }; + } + Self::from_array(out) + } + + /// Interleave low bytes within each 128-bit lane. + /// + /// For each of the 2 × 16-byte lanes, takes the first 8 bytes of + /// `self` and `other` and interleaves them: + /// `[self[0], other[0], self[1], other[1], ..., self[7], other[7]]`. + /// + /// Mirrors `_mm256_unpacklo_epi8` semantics. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// let a = U8x32::splat(0xAA); + /// let b = U8x32::splat(0xBB); + /// let r = a.unpack_lo_epi8(b); + /// assert_eq!(r.to_array()[0], 0xAA); + /// assert_eq!(r.to_array()[1], 0xBB); + /// # } + /// ``` + #[inline] + pub fn unpack_lo_epi8(self, other: Self) -> Self { + let a = self.to_array(); + let b = other.to_array(); + let mut out = [0u8; 32]; + // 2 × 128-bit lanes, each 16 bytes wide. 
+ for lane in 0..2 { + let base = lane * 16; + for i in 0..8 { + out[base + i * 2] = a[base + i]; + out[base + i * 2 + 1] = b[base + i]; + } + } + Self::from_array(out) + } + + /// Interleave high bytes within each 128-bit lane. + /// + /// For each of the 2 × 16-byte lanes, takes the last 8 bytes of + /// `self` and `other` and interleaves them: + /// `[self[8], other[8], self[9], other[9], ..., self[15], other[15]]`. + /// + /// Mirrors `_mm256_unpackhi_epi8` semantics. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// let a = U8x32::splat(0xAA); + /// let b = U8x32::splat(0xBB); + /// let r = a.unpack_hi_epi8(b); + /// assert_eq!(r.to_array()[0], 0xAA); + /// assert_eq!(r.to_array()[1], 0xBB); + /// # } + /// ``` + #[inline] + pub fn unpack_hi_epi8(self, other: Self) -> Self { + let a = self.to_array(); + let b = other.to_array(); + let mut out = [0u8; 32]; + for lane in 0..2 { + let base = lane * 16; + for i in 0..8 { + out[base + i * 2] = a[base + 8 + i]; + out[base + i * 2 + 1] = b[base + 8 + i]; + } + } + Self::from_array(out) + } +} + +// ════════════════════════════════════════════════════════════════════ +// Tests +// ════════════════════════════════════════════════════════════════════ + +#[cfg(test)] +mod tests { + use super::*; + + // ── U8x64 tests ───────────────────────────────────────────────── + + #[test] + fn u8x64_permute_bytes_identity() { + let v = U8x64::from_array(core::array::from_fn(|i| i as u8)); + let idx = U8x64::from_array(core::array::from_fn(|i| i as u8)); + assert_eq!(v.permute_bytes(idx).to_array(), v.to_array()); + } + + #[test] + fn u8x64_permute_bytes_reverse() { + let v = U8x64::from_array(core::array::from_fn(|i| i as u8)); + let idx = U8x64::from_array(core::array::from_fn(|i| (63 - i) as u8)); + let r = v.permute_bytes(idx); + for i in 0..64 { + assert_eq!(r.to_array()[i], (63 - i) as u8, "lane {i}"); + } + } + + #[test] + fn u8x64_shuffle_bytes_lut() { + let lut = U8x64::nibble_popcount_lut(); + // Index 0..16 → popcount of nibble. 
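+        // Assumes nibble_popcount_lut() repeats
+        // [0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4] in every 16-byte lane,
+        // matching the `expected` table below.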
+ let idx_arr: [u8; 64] = core::array::from_fn(|i| (i % 16) as u8); + let idx = U8x64::from_array(idx_arr); + let r = lut.shuffle_bytes(idx); + let expected = [0u8, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4]; + for lane in 0..4 { + let base = lane * 16; + for i in 0..16 { + assert_eq!(r.to_array()[base + i], expected[i], "lane {lane} pos {i}"); + } + } + } + + #[test] + fn u8x64_shuffle_bytes_high_bit_zeroes() { + let lut = U8x64::nibble_popcount_lut(); + let idx = U8x64::splat(0x80); // high bit set → zero + assert!(lut.shuffle_bytes(idx).to_array().iter().all(|&b| b == 0)); + } + + #[test] + fn u8x64_mask_blend_all_a() { + let a = U8x64::splat(10); + let b = U8x64::splat(20); + let r = U8x64::mask_blend(0u64, a, b); + assert!(r.to_array().iter().all(|&x| x == 10)); + } + + #[test] + fn u8x64_mask_blend_all_b() { + let a = U8x64::splat(10); + let b = U8x64::splat(20); + let r = U8x64::mask_blend(u64::MAX, a, b); + assert!(r.to_array().iter().all(|&x| x == 20)); + } + + #[test] + fn u8x64_mask_blend_lane0_from_b() { + let a = U8x64::splat(10); + let b = U8x64::splat(20); + let r = U8x64::mask_blend(1u64, a, b); + assert_eq!(r.to_array()[0], 20); + assert_eq!(r.to_array()[1], 10); + } + + #[test] + fn u8x64_unpack_lo_interleave() { + let a = U8x64::splat(0xAA); + let b = U8x64::splat(0xBB); + let r = a.unpack_lo_epi8(b); + let arr = r.to_array(); + for lane in 0..4 { + let base = lane * 16; + for i in 0..8 { + assert_eq!(arr[base + i * 2], 0xAA, "lane {lane} pos {}", i * 2); + assert_eq!(arr[base + i * 2 + 1], 0xBB, "lane {lane} pos {}", i * 2 + 1); + } + } + } + + #[test] + fn u8x64_unpack_hi_interleave() { + let mut a_arr = [0u8; 64]; + let mut b_arr = [0u8; 64]; + for lane in 0..4 { + for i in 0..16 { + a_arr[lane * 16 + i] = i as u8; + b_arr[lane * 16 + i] = (i + 100) as u8; + } + } + let a = U8x64::from_array(a_arr); + let b = U8x64::from_array(b_arr); + let r = a.unpack_hi_epi8(b); + let arr = r.to_array(); + for lane in 0..4 { + let base = lane * 16; + for i in 0..8 { + assert_eq!(arr[base + i * 2], (8 + i) as u8, "lane {lane} a[{i}]"); + assert_eq!(arr[base + i * 2 + 1], (108 + i) as u8, "lane {lane} b[{i}]"); + } + } + } + + // ── U8x32 tests ───────────────────────────────────────────────── + + #[test] + fn u8x32_permute_bytes_reverse() { + let v = U8x32::from_array(core::array::from_fn(|i| i as u8)); + let idx = U8x32::from_array(core::array::from_fn(|i| (31 - i) as u8)); + let r = v.permute_bytes(idx); + for i in 0..32 { + assert_eq!(r.to_array()[i], (31 - i) as u8, "lane {i}"); + } + } + + #[test] + fn u8x32_shuffle_bytes_lut() { + let lut = U8x32::nibble_popcount_lut(); + let idx_arr: [u8; 32] = core::array::from_fn(|i| (i % 16) as u8); + let idx = U8x32::from_array(idx_arr); + let r = lut.shuffle_bytes(idx); + let expected = [0u8, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4]; + for lane in 0..2 { + let base = lane * 16; + for i in 0..16 { + assert_eq!(r.to_array()[base + i], expected[i], "lane {lane} pos {i}"); + } + } + } + + #[test] + fn u8x32_shuffle_bytes_high_bit_zeroes() { + let lut = U8x32::nibble_popcount_lut(); + let idx = U8x32::splat(0x80); + assert!(lut.shuffle_bytes(idx).to_array().iter().all(|&b| b == 0)); + } + + #[test] + fn u8x32_mask_blend_all_a() { + let a = U8x32::splat(10); + let b = U8x32::splat(20); + let r = U8x32::mask_blend(0u32, a, b); + assert!(r.to_array().iter().all(|&x| x == 10)); + } + + #[test] + fn u8x32_mask_blend_all_b() { + let a = U8x32::splat(10); + let b = U8x32::splat(20); + let r = U8x32::mask_blend(u32::MAX, a, b); + 
assert!(r.to_array().iter().all(|&x| x == 20)); + } + + #[test] + fn u8x32_mask_blend_lane0_from_b() { + let a = U8x32::splat(10); + let b = U8x32::splat(20); + let r = U8x32::mask_blend(1u32, a, b); + assert_eq!(r.to_array()[0], 20); + assert_eq!(r.to_array()[1], 10); + } + + #[test] + fn u8x32_unpack_lo_interleave() { + let a = U8x32::splat(0xAA); + let b = U8x32::splat(0xBB); + let r = a.unpack_lo_epi8(b); + let arr = r.to_array(); + for lane in 0..2 { + let base = lane * 16; + for i in 0..8 { + assert_eq!(arr[base + i * 2], 0xAA); + assert_eq!(arr[base + i * 2 + 1], 0xBB); + } + } + } + + #[test] + fn u8x32_unpack_hi_interleave() { + let mut a_arr = [0u8; 32]; + let mut b_arr = [0u8; 32]; + for lane in 0..2 { + for i in 0..16 { + a_arr[lane * 16 + i] = i as u8; + b_arr[lane * 16 + i] = (i + 100) as u8; + } + } + let a = U8x32::from_array(a_arr); + let b = U8x32::from_array(b_arr); + let r = a.unpack_hi_epi8(b); + let arr = r.to_array(); + for lane in 0..2 { + let base = lane * 16; + for i in 0..8 { + assert_eq!(arr[base + i * 2], (8 + i) as u8, "lane {lane} a[{i}]"); + assert_eq!(arr[base + i * 2 + 1], (108 + i) as u8, "lane {lane} b[{i}]"); + } + } + } +} diff --git a/src/simd_nightly/f16_types.rs b/src/simd_nightly/f16_types.rs new file mode 100644 index 00000000..4b134beb --- /dev/null +++ b/src/simd_nightly/f16_types.rs @@ -0,0 +1,251 @@ +//! F16x16 portable-simd wrapper (scalar IEEE-754 binary16 emulation — no +//! core::simd half-precision). Round-3-portable-simd agent #9. +#![cfg(feature = "nightly-simd")] + +/// 16-lane IEEE-754 binary16 (half-precision) vector backed by `[u16; 16]`. +/// +/// `core::simd` has no native `f16` lane type, so this is a full scalar +/// emulation. All conversions use the same IEEE-754-correct +/// round-to-nearest-even logic as `src/hpc/quantized.rs` (see +/// `F16::from_f32_rounded`, lines 193-264, and `F16::to_f32`, lines 267-301). +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct F16x16(pub [u16; 16]); + +impl F16x16 { + /// Number of lanes. + pub const LANES: usize = 16; + + // ── IEEE-754 binary16 scalar helpers ──────────────────────────────────── + // + // Copied from src/hpc/quantized.rs, `F16::from_f32_rounded` (lines 193-264) + // and `F16::to_f32` (lines 267-301). + + /// Convert a single `f32` value to its IEEE-754 binary16 bit pattern + /// (round-to-nearest-even). + /// + /// Logic mirrors `F16::from_f32_rounded` in `src/hpc/quantized.rs:193`. + #[inline] + fn f32_to_f16_bits(v: f32) -> u16 { + let bits = v.to_bits(); + let sign = ((bits >> 16) & 0x8000) as u16; + let exp = ((bits >> 23) & 0xFF) as i32; + let mant = bits & 0x007F_FFFF; + + if exp == 0xFF { + // Inf / NaN + if mant == 0 { + return sign | 0x7C00; + } + let mant16 = (mant >> 13) as u16; + let mant16 = if mant16 == 0 { 0x200 } else { mant16 | 0x200 }; + return sign | 0x7C00 | mant16; + } + + let new_exp = exp - 127 + 15; + + if new_exp >= 0x1F { + return sign | 0x7C00; // overflow → Inf + } + if new_exp >= 1 { + // Normal range — round-to-nearest-even on the 13 dropped bits. 
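+            // Worked case: dropped bits exactly 0x1000 are a tie, which
+            // rounds up only when the truncated mantissa is odd. E.g.
+            // 1.0f32 + 2^-11 has dropped bits == 0x1000 and truncated == 0
+            // (even), so it rounds down to 1.0 in binary16.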
+ let half_bits = mant & 0x0000_1FFF; + let truncated = (mant >> 13) as u32; + let half = 0x1000u32; + let round_up = if half_bits > half { + 1u32 + } else if half_bits < half { + 0 + } else { + truncated & 1 // tie → round to even + }; + let mant16 = truncated + round_up; + if mant16 == 0x400 { + // mantissa overflow bumps exponent + let new_exp = new_exp + 1; + if new_exp >= 0x1F { + return sign | 0x7C00; + } + return sign | ((new_exp as u16) << 10); + } + return sign | ((new_exp as u16) << 10) | (mant16 as u16); + } + + // Subnormal / underflow + if new_exp < -10 { + return sign; // underflow → ±0 + } + let mant_full = mant | 0x0080_0000; // 24-bit with implicit 1 + let shift = (14 - new_exp) as u32; + let truncated = mant_full >> shift; + let dropped_mask = (1u32 << shift) - 1; + let dropped = mant_full & dropped_mask; + let half = 1u32 << (shift - 1); + let round_up = if dropped > half { + 1u32 + } else if dropped < half { + 0 + } else { + truncated & 1 + }; + let mant16 = (truncated + round_up) as u16; + sign | mant16 + } + + /// Convert an IEEE-754 binary16 bit pattern to `f32` (lossless). + /// + /// Logic mirrors `F16::to_f32` in `src/hpc/quantized.rs:267`. + #[inline] + fn f16_bits_to_f32(h: u16) -> f32 { + let h = h as u32; + let sign = (h & 0x8000) << 16; + let exp = (h >> 10) & 0x1F; + let mant = h & 0x03FF; + + let bits = if exp == 0 { + if mant == 0 { + sign // ±0 + } else { + // Subnormal: normalize + let mut m = mant; + let mut e: i32 = 1; + while (m & 0x0400) == 0 { + m <<= 1; + e -= 1; + } + let m = m & 0x03FF; + let new_exp = (e - 1 + 127 - 14) as u32; + sign | (new_exp << 23) | (m << 13) + } + } else if exp == 0x1F { + if mant == 0 { + sign | 0x7F80_0000 // ±Inf + } else { + sign | 0x7F80_0000 | (mant << 13) // NaN + } + } else { + let new_exp = exp + (127 - 15); + sign | (new_exp << 23) | (mant << 13) + }; + f32::from_bits(bits) + } + + // ── Constructors ──────────────────────────────────────────────────────── + + /// Broadcast `v` (converted from `f32` to IEEE-754 binary16) across all 16 lanes. + #[inline] + pub fn splat(v: f32) -> Self { + F16x16([Self::f32_to_f16_bits(v); 16]) + } + + /// Load 16 binary16 values from a `u16` slice (must have `len >= 16`). + #[inline] + pub fn from_slice(s: &[u16]) -> Self { + assert!(s.len() >= 16, "F16x16::from_slice: need >= 16 elements, got {}", s.len()); + let mut arr = [0u16; 16]; + arr.copy_from_slice(&s[..16]); + F16x16(arr) + } + + /// Construct directly from a raw `[u16; 16]` array of binary16 bit patterns. + #[inline] + pub fn from_array(arr: [u16; 16]) -> Self { + F16x16(arr) + } + + /// Return the raw `[u16; 16]` array of binary16 bit patterns. + #[inline] + pub fn to_array(self) -> [u16; 16] { + self.0 + } + + /// Write all 16 binary16 bit patterns into a `u16` slice (must have `len >= 16`). + #[inline] + pub fn copy_to_slice(self, s: &mut [u16]) { + assert!(s.len() >= 16, "F16x16::copy_to_slice: need >= 16 elements, got {}", s.len()); + s[..16].copy_from_slice(&self.0); + } + + // ── Conversions ───────────────────────────────────────────────────────── + + /// Upcast all 16 binary16 lanes to `f32`. + #[inline] + pub fn to_f32_array(self) -> [f32; 16] { + let mut out = [0.0f32; 16]; + for i in 0..16 { + out[i] = Self::f16_bits_to_f32(self.0[i]); + } + out + } + + /// Convert 16 `f32` values to binary16 and pack into a new `F16x16`. 
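+    ///
+    /// A minimal roundtrip sketch (0.5 is exactly representable in
+    /// binary16 as `0x3800`, so no rounding is involved):
+    ///
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::F16x16;
+    /// let v = F16x16::from_f32_array([0.5f32; 16]);
+    /// assert_eq!(v.to_array()[0], 0x3800);
+    /// assert_eq!(v.to_f32_array()[0], 0.5);
+    /// # }
+    /// ```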
+ #[inline] + pub fn from_f32_array(arr: [f32; 16]) -> Self { + let mut out = [0u16; 16]; + for i in 0..16 { + out[i] = Self::f32_to_f16_bits(arr[i]); + } + F16x16(out) + } +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn splat_roundtrip() { + let v = F16x16::splat(1.0_f32); + let arr = v.to_f32_array(); + for lane in arr { + assert!((lane - 1.0).abs() < 1e-5, "splat(1.0) lane = {lane}"); + } + } + + #[test] + fn from_array_to_array_identity() { + let raw = [0x3C00u16; 16]; // 1.0 in binary16 + let v = F16x16::from_array(raw); + assert_eq!(v.to_array(), raw); + } + + #[test] + fn from_slice_copy_to_slice_roundtrip() { + let src: [u16; 16] = core::array::from_fn(|i| i as u16 + 0x3C00); + let v = F16x16::from_slice(&src); + let mut dst = [0u16; 16]; + v.copy_to_slice(&mut dst); + assert_eq!(dst, src); + } + + #[test] + fn from_f32_array_to_f32_array_roundtrip() { + let inputs: [f32; 16] = core::array::from_fn(|i| i as f32 * 0.5); + let v = F16x16::from_f32_array(inputs); + let out = v.to_f32_array(); + for (i, (&orig, &back)) in inputs.iter().zip(out.iter()).enumerate() { + assert!((orig - back).abs() < 0.001, "lane {i}: {orig} → f16 → {back}"); + } + } + + #[test] + fn special_values_inf_nan() { + let v = F16x16::splat(f32::INFINITY); + for lane in v.to_f32_array() { + assert!(lane.is_infinite() && lane > 0.0); + } + let v = F16x16::splat(f32::NAN); + for lane in v.to_f32_array() { + assert!(lane.is_nan()); + } + } + + #[test] + fn lanes_constant() { + assert_eq!(F16x16::LANES, 16); + } +} diff --git a/src/simd_nightly/f32_types.rs b/src/simd_nightly/f32_types.rs new file mode 100644 index 00000000..3370cf34 --- /dev/null +++ b/src/simd_nightly/f32_types.rs @@ -0,0 +1,395 @@ +//! F32x16 / F32x8 portable-simd wrappers — round-3-portable-simd agent #1. +#![cfg(feature = "nightly-simd")] + +use core::simd::cmp::{SimdPartialEq, SimdPartialOrd}; +use core::simd::num::SimdFloat; +use core::simd::{f32x16 as core_f32x16, f32x8 as core_f32x8}; +// `mul_add`, `sqrt`, `round`, `floor`, `abs` live in `StdFloat` (std-only nightly trait). +use std::simd::StdFloat; + +use super::masks::{F32Mask16, F32Mask8}; +use super::u_word_types::{U32x16, U32x8}; + +// ════════════════════════════════════════════════════════════════════ +// F32x16 — 16-lane single-precision float +// ════════════════════════════════════════════════════════════════════ + +/// 16-lane `f32` SIMD vector backed by `core::simd::f32x16`. +/// +/// API mirrors `simd_avx512::F32x16` so consumer code is identical. +/// Miri can execute every method below — unlike the intrinsics +/// backend, where SIMD paths are opaque to miri. +#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(transparent)] +pub struct F32x16(pub core_f32x16); + +impl F32x16 { + pub const LANES: usize = 16; + + // ── Constructors ────────────────────────────────────────────── + + /// Broadcast `v` to all 16 lanes. + #[inline(always)] + pub fn splat(v: f32) -> Self { + Self(core_f32x16::splat(v)) + } + + /// Load from the first 16 elements of `arr`. + #[inline(always)] + pub fn from_array(arr: [f32; 16]) -> Self { + Self(core_f32x16::from_array(arr)) + } + + /// Load from the first 16 elements of `s`. + /// + /// # Panics + /// Panics if `s.len() < 16`. 
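+    ///
+    /// # Examples
+    /// A quick sketch — only the first `LANES` elements are read:
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::F32x16;
+    /// let buf = [1.0f32; 32];
+    /// let v = F32x16::from_slice(&buf);
+    /// assert_eq!(v.reduce_sum(), 16.0);
+    /// # }
+    /// ```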
+ #[inline(always)] + pub fn from_slice(s: &[f32]) -> Self { + assert!(s.len() >= 16, "F32x16::from_slice needs >= 16 elements"); + Self(core_f32x16::from_slice(s)) + } + + /// Copy all 16 lanes into a `[f32; 16]`. + #[inline(always)] + pub fn to_array(self) -> [f32; 16] { + self.0.to_array() + } + + /// Store all 16 lanes into the first 16 slots of `s`. + /// + /// # Panics + /// Panics if `s.len() < 16`. + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [f32]) { + assert!(s.len() >= 16, "F32x16::copy_to_slice needs >= 16 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + /// Sum of all 16 lanes. + #[inline(always)] + pub fn reduce_sum(self) -> f32 { + self.0.reduce_sum() + } + + /// Minimum lane value. + #[inline(always)] + pub fn reduce_min(self) -> f32 { + self.0.reduce_min() + } + + /// Maximum lane value. + #[inline(always)] + pub fn reduce_max(self) -> f32 { + self.0.reduce_max() + } + + // ── Lane-wise min / max / clamp ─────────────────────────────── + + /// Per-lane minimum. + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + /// Per-lane maximum. + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + /// Per-lane clamp: `lo <= self <= hi` for each lane. + #[inline(always)] + pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { + Self(self.0.simd_clamp(lo.0, hi.0)) + } + + // ── FMA + math ──────────────────────────────────────────────── + + /// Fused multiply-add: `self * b + c`. + /// + /// Maps to `_mm512_fmadd_ps` on AVX-512 builds via LLVM's portable-simd + /// lowering; scalar `f32::mul_add` per lane otherwise. + #[inline(always)] + pub fn mul_add(self, b: Self, c: Self) -> Self { + Self(self.0.mul_add(b.0, c.0)) + } + + /// Per-lane square root. + #[inline(always)] + pub fn sqrt(self) -> Self { + Self(self.0.sqrt()) + } + + /// Per-lane round-to-nearest (ties to even). + #[inline(always)] + pub fn round(self) -> Self { + Self(self.0.round()) + } + + /// Per-lane floor (toward negative infinity). + #[inline(always)] + pub fn floor(self) -> Self { + Self(self.0.floor()) + } + + /// Per-lane absolute value (clears sign bit). + #[inline(always)] + pub fn abs(self) -> Self { + Self(self.0.abs()) + } + + // ── Bit reinterpretation ────────────────────────────────────── + + /// Reinterpret the 16 × 32-bit float lanes as `U32x16` (no conversion). + #[inline(always)] + pub fn to_bits(self) -> U32x16 { + U32x16(self.0.to_bits()) + } + + /// Reinterpret a `U32x16` as `F32x16` (no conversion). + #[inline(always)] + pub fn from_bits(bits: U32x16) -> Self { + Self(core_f32x16::from_bits(bits.0)) + } + + // ── Comparisons → typed masks ───────────────────────────────── + // + // Return `super::masks::F32Mask16` (agent #7's type), which wraps + // `core::simd::Mask`. + + /// Per-lane equality: `self[i] == other[i]`. + #[inline(always)] + pub fn simd_eq(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_eq(other.0)) + } + + /// Per-lane inequality: `self[i] != other[i]`. + #[inline(always)] + pub fn simd_ne(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_ne(other.0)) + } + + /// Per-lane less-than: `self[i] < other[i]`. + #[inline(always)] + pub fn simd_lt(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_lt(other.0)) + } + + /// Per-lane less-or-equal: `self[i] <= other[i]`. 
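+    /// For example, `splat(1.0).simd_le(splat(1.0)).to_bitmask() == 0xFFFF`
+    /// — every one of the 16 lanes compares true.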
+ #[inline(always)] + pub fn simd_le(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_le(other.0)) + } + + /// Per-lane greater-than: `self[i] > other[i]`. + #[inline(always)] + pub fn simd_gt(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_gt(other.0)) + } + + /// Per-lane greater-or-equal: `self[i] >= other[i]`. + #[inline(always)] + pub fn simd_ge(self, other: Self) -> F32Mask16 { + F32Mask16(self.0.simd_ge(other.0)) + } +} + +impl Default for F32x16 { + #[inline(always)] + fn default() -> Self { + Self::splat(0.0) + } +} + +// ════════════════════════════════════════════════════════════════════ +// F32x8 — 8-lane single-precision float +// ════════════════════════════════════════════════════════════════════ + +/// 8-lane `f32` SIMD vector backed by `core::simd::f32x8`. +/// +/// API mirrors `simd_avx512::F32x16` / `F32x8` so consumer code is +/// identical. Miri can execute every method below. +#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(transparent)] +pub struct F32x8(pub core_f32x8); + +impl F32x8 { + pub const LANES: usize = 8; + + // ── Constructors ────────────────────────────────────────────── + + /// Broadcast `v` to all 8 lanes. + #[inline(always)] + pub fn splat(v: f32) -> Self { + Self(core_f32x8::splat(v)) + } + + /// Load from the first 8 elements of `arr`. + #[inline(always)] + pub fn from_array(arr: [f32; 8]) -> Self { + Self(core_f32x8::from_array(arr)) + } + + /// Load from the first 8 elements of `s`. + /// + /// # Panics + /// Panics if `s.len() < 8`. + #[inline(always)] + pub fn from_slice(s: &[f32]) -> Self { + assert!(s.len() >= 8, "F32x8::from_slice needs >= 8 elements"); + Self(core_f32x8::from_slice(s)) + } + + /// Copy all 8 lanes into a `[f32; 8]`. + #[inline(always)] + pub fn to_array(self) -> [f32; 8] { + self.0.to_array() + } + + /// Store all 8 lanes into the first 8 slots of `s`. + /// + /// # Panics + /// Panics if `s.len() < 8`. + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [f32]) { + assert!(s.len() >= 8, "F32x8::copy_to_slice needs >= 8 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + /// Sum of all 8 lanes. + #[inline(always)] + pub fn reduce_sum(self) -> f32 { + self.0.reduce_sum() + } + + /// Minimum lane value. + #[inline(always)] + pub fn reduce_min(self) -> f32 { + self.0.reduce_min() + } + + /// Maximum lane value. + #[inline(always)] + pub fn reduce_max(self) -> f32 { + self.0.reduce_max() + } + + // ── Lane-wise min / max / clamp ─────────────────────────────── + + /// Per-lane minimum. + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + /// Per-lane maximum. + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + /// Per-lane clamp: `lo <= self <= hi` for each lane. + #[inline(always)] + pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { + Self(self.0.simd_clamp(lo.0, hi.0)) + } + + // ── FMA + math ──────────────────────────────────────────────── + + /// Fused multiply-add: `self * b + c`. + #[inline(always)] + pub fn mul_add(self, b: Self, c: Self) -> Self { + Self(self.0.mul_add(b.0, c.0)) + } + + /// Per-lane square root. + #[inline(always)] + pub fn sqrt(self) -> Self { + Self(self.0.sqrt()) + } + + /// Per-lane round-to-nearest (ties to even). + #[inline(always)] + pub fn round(self) -> Self { + Self(self.0.round()) + } + + /// Per-lane floor (toward negative infinity). 
+ #[inline(always)] + pub fn floor(self) -> Self { + Self(self.0.floor()) + } + + /// Per-lane absolute value (clears sign bit). + #[inline(always)] + pub fn abs(self) -> Self { + Self(self.0.abs()) + } + + // ── Bit reinterpretation ────────────────────────────────────── + + /// Reinterpret the 8 × 32-bit float lanes as `U32x8` (no conversion). + #[inline(always)] + pub fn to_bits(self) -> U32x8 { + U32x8(self.0.to_bits()) + } + + /// Reinterpret a `U32x8` as `F32x8` (no conversion). + #[inline(always)] + pub fn from_bits(bits: U32x8) -> Self { + Self(core_f32x8::from_bits(bits.0)) + } + + // ── Comparisons → typed masks ───────────────────────────────── + // + // Return `super::masks::F32Mask8` (agent #7's type), which wraps + // `core::simd::Mask`. + + /// Per-lane equality: `self[i] == other[i]`. + #[inline(always)] + pub fn simd_eq(self, other: Self) -> F32Mask8 { + F32Mask8(self.0.simd_eq(other.0)) + } + + /// Per-lane inequality: `self[i] != other[i]`. + #[inline(always)] + pub fn simd_ne(self, other: Self) -> F32Mask8 { + F32Mask8(self.0.simd_ne(other.0)) + } + + /// Per-lane less-than: `self[i] < other[i]`. + #[inline(always)] + pub fn simd_lt(self, other: Self) -> F32Mask8 { + F32Mask8(self.0.simd_lt(other.0)) + } + + /// Per-lane less-or-equal: `self[i] <= other[i]`. + #[inline(always)] + pub fn simd_le(self, other: Self) -> F32Mask8 { + F32Mask8(self.0.simd_le(other.0)) + } + + /// Per-lane greater-than: `self[i] > other[i]`. + #[inline(always)] + pub fn simd_gt(self, other: Self) -> F32Mask8 { + F32Mask8(self.0.simd_gt(other.0)) + } + + /// Per-lane greater-or-equal: `self[i] >= other[i]`. + #[inline(always)] + pub fn simd_ge(self, other: Self) -> F32Mask8 { + F32Mask8(self.0.simd_ge(other.0)) + } +} + +impl Default for F32x8 { + #[inline(always)] + fn default() -> Self { + Self::splat(0.0) + } +} diff --git a/src/simd_nightly/f64_types.rs b/src/simd_nightly/f64_types.rs new file mode 100644 index 00000000..3d85881c --- /dev/null +++ b/src/simd_nightly/f64_types.rs @@ -0,0 +1,346 @@ +//! F64x8 / F64x4 portable-simd wrappers — round-3-portable-simd agent #2. +#![cfg(feature = "nightly-simd")] + +use core::simd::cmp::{SimdPartialEq, SimdPartialOrd}; +use core::simd::num::SimdFloat; +use core::simd::{f64x4 as core_f64x4, f64x8 as core_f64x8}; +use std::simd::StdFloat; + +// ════════════════════════════════════════════════════════════════════ +// F64x8 — 8-lane double-precision float +// ════════════════════════════════════════════════════════════════════ + +/// 8-lane `f64` SIMD vector backed by `core::simd::f64x8`. +/// +/// API mirrors `simd_avx512::F64x8` so consumer code is identical. +/// Miri can execute every method — unlike the intrinsics backend where +/// SIMD paths are opaque to miri. +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct F64x8(pub core_f64x8); + +impl F64x8 { + /// Number of f64 lanes. + pub const LANES: usize = 8; + + // ── Constructors ─────────────────────────────────────────────── + + /// Broadcast a scalar to all 8 lanes. + #[inline(always)] + pub fn splat(v: f64) -> Self { + Self(core_f64x8::splat(v)) + } + + /// Load from the first 8 elements of `s`. Panics if `s.len() < 8`. + #[inline(always)] + pub fn from_slice(s: &[f64]) -> Self { + assert!(s.len() >= 8, "F64x8::from_slice needs ≥8 elements"); + Self(core_f64x8::from_slice(s)) + } + + /// Load from a fixed-size array. + #[inline(always)] + pub fn from_array(arr: [f64; 8]) -> Self { + Self(core_f64x8::from_array(arr)) + } + + /// Convert to a fixed-size array. 
+ #[inline(always)] + pub fn to_array(self) -> [f64; 8] { + self.0.to_array() + } + + /// Store to the first 8 elements of `s`. Panics if `s.len() < 8`. + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [f64]) { + assert!(s.len() >= 8, "F64x8::copy_to_slice needs ≥8 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ───────────────────────────────────────────────── + + /// Horizontal sum of all 8 lanes. + #[inline(always)] + pub fn reduce_sum(self) -> f64 { + self.0.reduce_sum() + } + + /// Horizontal minimum across all 8 lanes. + #[inline(always)] + pub fn reduce_min(self) -> f64 { + self.0.reduce_min() + } + + /// Horizontal maximum across all 8 lanes. + #[inline(always)] + pub fn reduce_max(self) -> f64 { + self.0.reduce_max() + } + + // ── Lane-wise min / max / clamp ──────────────────────────────── + + /// Per-lane minimum. + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + /// Per-lane maximum. + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + /// Per-lane clamp to `[lo, hi]`. + #[inline(always)] + pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { + Self(self.0.simd_clamp(lo.0, hi.0)) + } + + // ── FMA + math ───────────────────────────────────────────────── + + /// Fused multiply-add: `self * b + c`. + #[inline(always)] + pub fn mul_add(self, b: Self, c: Self) -> Self { + Self(self.0.mul_add(b.0, c.0)) + } + + /// Per-lane square root. + #[inline(always)] + pub fn sqrt(self) -> Self { + Self(self.0.sqrt()) + } + + /// Per-lane round-to-nearest-even. + #[inline(always)] + pub fn round(self) -> Self { + Self(self.0.round()) + } + + /// Per-lane floor. + #[inline(always)] + pub fn floor(self) -> Self { + Self(self.0.floor()) + } + + /// Per-lane absolute value. + #[inline(always)] + pub fn abs(self) -> Self { + Self(self.0.abs()) + } + + // ── Bit reinterpretation ─────────────────────────────────────── + + /// Reinterpret the bit pattern as a `U64x8`. + #[inline(always)] + pub fn to_bits(self) -> super::u_word_types::U64x8 { + super::u_word_types::U64x8(self.0.to_bits()) + } + + // ── Comparisons → mask ───────────────────────────────────────── + + /// Per-lane `==`. + #[inline(always)] + pub fn simd_eq(self, other: Self) -> super::masks::F64Mask8 { + super::masks::F64Mask8(self.0.simd_eq(other.0)) + } + + /// Per-lane `!=`. + #[inline(always)] + pub fn simd_ne(self, other: Self) -> super::masks::F64Mask8 { + super::masks::F64Mask8(self.0.simd_ne(other.0)) + } + + /// Per-lane `<`. + #[inline(always)] + pub fn simd_lt(self, other: Self) -> super::masks::F64Mask8 { + super::masks::F64Mask8(self.0.simd_lt(other.0)) + } + + /// Per-lane `<=`. + #[inline(always)] + pub fn simd_le(self, other: Self) -> super::masks::F64Mask8 { + super::masks::F64Mask8(self.0.simd_le(other.0)) + } + + /// Per-lane `>`. + #[inline(always)] + pub fn simd_gt(self, other: Self) -> super::masks::F64Mask8 { + super::masks::F64Mask8(self.0.simd_gt(other.0)) + } + + /// Per-lane `>=`. + #[inline(always)] + pub fn simd_ge(self, other: Self) -> super::masks::F64Mask8 { + super::masks::F64Mask8(self.0.simd_ge(other.0)) + } +} + +// ════════════════════════════════════════════════════════════════════ +// F64x4 — 4-lane double-precision float +// ════════════════════════════════════════════════════════════════════ + +/// 4-lane `f64` SIMD vector backed by `core::simd::f64x4`. +/// +/// API mirrors `simd_avx512::F64x4` at half the F64x8 width. 
+/// Miri-executable via the `core::simd` portable backend. +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct F64x4(pub core_f64x4); + +impl F64x4 { + /// Number of f64 lanes. + pub const LANES: usize = 4; + + // ── Constructors ─────────────────────────────────────────────── + + /// Broadcast a scalar to all 4 lanes. + #[inline(always)] + pub fn splat(v: f64) -> Self { + Self(core_f64x4::splat(v)) + } + + /// Load from the first 4 elements of `s`. Panics if `s.len() < 4`. + #[inline(always)] + pub fn from_slice(s: &[f64]) -> Self { + assert!(s.len() >= 4, "F64x4::from_slice needs ≥4 elements"); + Self(core_f64x4::from_slice(s)) + } + + /// Load from a fixed-size array. + #[inline(always)] + pub fn from_array(arr: [f64; 4]) -> Self { + Self(core_f64x4::from_array(arr)) + } + + /// Convert to a fixed-size array. + #[inline(always)] + pub fn to_array(self) -> [f64; 4] { + self.0.to_array() + } + + /// Store to the first 4 elements of `s`. Panics if `s.len() < 4`. + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [f64]) { + assert!(s.len() >= 4, "F64x4::copy_to_slice needs ≥4 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ───────────────────────────────────────────────── + + /// Horizontal sum of all 4 lanes. + #[inline(always)] + pub fn reduce_sum(self) -> f64 { + self.0.reduce_sum() + } + + /// Horizontal minimum across all 4 lanes. + #[inline(always)] + pub fn reduce_min(self) -> f64 { + self.0.reduce_min() + } + + /// Horizontal maximum across all 4 lanes. + #[inline(always)] + pub fn reduce_max(self) -> f64 { + self.0.reduce_max() + } + + // ── Lane-wise min / max / clamp ──────────────────────────────── + + /// Per-lane minimum. + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + /// Per-lane maximum. + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + /// Per-lane clamp to `[lo, hi]`. + #[inline(always)] + pub fn simd_clamp(self, lo: Self, hi: Self) -> Self { + Self(self.0.simd_clamp(lo.0, hi.0)) + } + + // ── FMA + math ───────────────────────────────────────────────── + + /// Fused multiply-add: `self * b + c`. + #[inline(always)] + pub fn mul_add(self, b: Self, c: Self) -> Self { + Self(self.0.mul_add(b.0, c.0)) + } + + /// Per-lane square root. + #[inline(always)] + pub fn sqrt(self) -> Self { + Self(self.0.sqrt()) + } + + /// Per-lane round-to-nearest-even. + #[inline(always)] + pub fn round(self) -> Self { + Self(self.0.round()) + } + + /// Per-lane floor. + #[inline(always)] + pub fn floor(self) -> Self { + Self(self.0.floor()) + } + + /// Per-lane absolute value. + #[inline(always)] + pub fn abs(self) -> Self { + Self(self.0.abs()) + } + + // ── Bit reinterpretation ─────────────────────────────────────── + + /// Reinterpret the bit pattern as a `U64x4`. + #[inline(always)] + pub fn to_bits(self) -> super::u_word_types::U64x4 { + super::u_word_types::U64x4(self.0.to_bits()) + } + + // ── Comparisons → mask ───────────────────────────────────────── + + /// Per-lane `==`. + #[inline(always)] + pub fn simd_eq(self, other: Self) -> super::masks::F64Mask4 { + super::masks::F64Mask4(self.0.simd_eq(other.0)) + } + + /// Per-lane `!=`. + #[inline(always)] + pub fn simd_ne(self, other: Self) -> super::masks::F64Mask4 { + super::masks::F64Mask4(self.0.simd_ne(other.0)) + } + + /// Per-lane `<`. 
+ #[inline(always)] + pub fn simd_lt(self, other: Self) -> super::masks::F64Mask4 { + super::masks::F64Mask4(self.0.simd_lt(other.0)) + } + + /// Per-lane `<=`. + #[inline(always)] + pub fn simd_le(self, other: Self) -> super::masks::F64Mask4 { + super::masks::F64Mask4(self.0.simd_le(other.0)) + } + + /// Per-lane `>`. + #[inline(always)] + pub fn simd_gt(self, other: Self) -> super::masks::F64Mask4 { + super::masks::F64Mask4(self.0.simd_gt(other.0)) + } + + /// Per-lane `>=`. + #[inline(always)] + pub fn simd_ge(self, other: Self) -> super::masks::F64Mask4 { + super::masks::F64Mask4(self.0.simd_ge(other.0)) + } +} diff --git a/src/simd_nightly/i8_types.rs b/src/simd_nightly/i8_types.rs new file mode 100644 index 00000000..dde432e4 --- /dev/null +++ b/src/simd_nightly/i8_types.rs @@ -0,0 +1,266 @@ +//! I8x32 / I8x64 portable-simd wrappers — round-3-portable-simd agent #5. +#![cfg(feature = "nightly-simd")] + +use core::simd::cmp::{SimdOrd, SimdPartialEq, SimdPartialOrd}; +use core::simd::num::SimdInt; +use core::simd::{i8x32, i8x64}; + +// ════════════════════════════════════════════════════════════════════ +// I8x64 — 64-lane signed byte +// Mirrors `simd_avx512::I8x64` surface (AVX-512BW path). +// ════════════════════════════════════════════════════════════════════ + +/// 64-lane `i8` SIMD vector backed by `core::simd::i8x64`. +/// +/// API mirrors `simd_avx512::I8x64` so consumer code is backend-agnostic. +/// Every method executes under miri, unlike the AVX-512 intrinsics path. +/// +/// # Examples +/// +/// ```rust,ignore +/// let a = I8x64::splat(3); +/// let b = I8x64::splat(5); +/// assert_eq!(a.reduce_max(), 3); +/// assert_eq!(a.saturating_add(b).to_array(), [8i8; 64]); +/// ``` +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct I8x64(pub i8x64); + +impl I8x64 { + /// Number of `i8` lanes. + pub const LANES: usize = 64; + + /// Broadcast `v` into every lane. + #[inline(always)] + pub fn splat(v: i8) -> Self { + Self(i8x64::splat(v)) + } + + /// Load from a slice of at least 64 elements (panics otherwise). + #[inline(always)] + pub fn from_slice(s: &[i8]) -> Self { + assert!(s.len() >= 64, "I8x64::from_slice needs ≥64 elements"); + Self(i8x64::from_slice(s)) + } + + /// Construct from a fixed-size array. + #[inline(always)] + pub fn from_array(arr: [i8; 64]) -> Self { + Self(i8x64::from_array(arr)) + } + + /// Extract all lanes as a fixed-size array. + #[inline(always)] + pub fn to_array(self) -> [i8; 64] { + self.0.to_array() + } + + /// Store all lanes into `s[0..64]` (panics if `s.len() < 64`). + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [i8]) { + assert!(s.len() >= 64, "I8x64::copy_to_slice needs ≥64 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + /// Wrapping horizontal sum of all 64 lanes. + #[inline(always)] + pub fn reduce_sum(self) -> i8 { + self.0.reduce_sum() + } + + /// Minimum across all lanes. + #[inline(always)] + pub fn reduce_min(self) -> i8 { + self.0.reduce_min() + } + + /// Maximum across all lanes. + #[inline(always)] + pub fn reduce_max(self) -> i8 { + self.0.reduce_max() + } + + // ── Lane-wise min / max ─────────────────────────────────────── + + /// Per-lane minimum. + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + /// Per-lane maximum. 
+ #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + // ── Saturating arithmetic ───────────────────────────────────── + + /// Per-lane signed saturating add. Results clamp to `[i8::MIN, i8::MAX]`. + #[inline(always)] + pub fn saturating_add(self, other: Self) -> Self { + Self(self.0.saturating_add(other.0)) + } + + /// Per-lane signed saturating subtract. Results clamp to `[i8::MIN, i8::MAX]`. + #[inline(always)] + pub fn saturating_sub(self, other: Self) -> Self { + Self(self.0.saturating_sub(other.0)) + } + + // ── Comparisons → bitmask ───────────────────────────────────── + + /// Per-lane equality mask. Returns a 64-bit integer; bit i is set iff + /// `self[i] == other[i]`. Matches the `__mmask64` shape of the + /// AVX-512BW backend. + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u64 { + self.0.simd_eq(other.0).to_bitmask() + } + + /// Per-lane signed greater-than mask. Returns a 64-bit integer; bit i is + /// set iff `self[i] > other[i]` (signed comparison). Matches the + /// `_mm512_cmpgt_epi8_mask` shape of the AVX-512BW backend. + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u64 { + self.0.simd_gt(other.0).to_bitmask() + } +} + +impl PartialEq for I8x64 { + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } +} + +// ════════════════════════════════════════════════════════════════════ +// I8x32 — 32-lane signed byte +// Mirrors `simd_avx512::I8x32` surface (AVX2/AVX-512BW path). +// ════════════════════════════════════════════════════════════════════ + +/// 32-lane `i8` SIMD vector backed by `core::simd::i8x32`. +/// +/// API mirrors `simd_avx512::I8x32` so consumer code is backend-agnostic. +/// +/// # Examples +/// +/// ```rust,ignore +/// let a = I8x32::splat(-10); +/// let b = I8x32::splat(20); +/// assert_eq!(a.saturating_add(b).to_array(), [10i8; 32]); +/// ``` +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct I8x32(pub i8x32); + +impl I8x32 { + /// Number of `i8` lanes. + pub const LANES: usize = 32; + + /// Broadcast `v` into every lane. + #[inline(always)] + pub fn splat(v: i8) -> Self { + Self(i8x32::splat(v)) + } + + /// Load from a slice of at least 32 elements (panics otherwise). + #[inline(always)] + pub fn from_slice(s: &[i8]) -> Self { + assert!(s.len() >= 32, "I8x32::from_slice needs ≥32 elements"); + Self(i8x32::from_slice(s)) + } + + /// Construct from a fixed-size array. + #[inline(always)] + pub fn from_array(arr: [i8; 32]) -> Self { + Self(i8x32::from_array(arr)) + } + + /// Extract all lanes as a fixed-size array. + #[inline(always)] + pub fn to_array(self) -> [i8; 32] { + self.0.to_array() + } + + /// Store all lanes into `s[0..32]` (panics if `s.len() < 32`). + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [i8]) { + assert!(s.len() >= 32, "I8x32::copy_to_slice needs ≥32 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + /// Wrapping horizontal sum of all 32 lanes. + #[inline(always)] + pub fn reduce_sum(self) -> i8 { + self.0.reduce_sum() + } + + /// Minimum across all lanes. + #[inline(always)] + pub fn reduce_min(self) -> i8 { + self.0.reduce_min() + } + + /// Maximum across all lanes. + #[inline(always)] + pub fn reduce_max(self) -> i8 { + self.0.reduce_max() + } + + // ── Lane-wise min / max ─────────────────────────────────────── + + /// Per-lane minimum. 
+ #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + /// Per-lane maximum. + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + // ── Saturating arithmetic ───────────────────────────────────── + + /// Per-lane signed saturating add. Results clamp to `[i8::MIN, i8::MAX]`. + #[inline(always)] + pub fn saturating_add(self, other: Self) -> Self { + Self(self.0.saturating_add(other.0)) + } + + /// Per-lane signed saturating subtract. Results clamp to `[i8::MIN, i8::MAX]`. + #[inline(always)] + pub fn saturating_sub(self, other: Self) -> Self { + Self(self.0.saturating_sub(other.0)) + } + + // ── Comparisons → bitmask ───────────────────────────────────── + + /// Per-lane equality mask. Returns a 32-bit integer; bit i is set iff + /// `self[i] == other[i]`. Matches the 32-lane mask shape of the + /// AVX2 backend's `movemask_epi8` output. + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u32 { + self.0.simd_eq(other.0).to_bitmask() as u32 + } + + /// Per-lane signed greater-than mask. Returns a 32-bit integer; bit i is + /// set iff `self[i] > other[i]` (signed comparison). Matches the + /// `_mm256_movemask_epi8(_mm256_cmpgt_epi8(...))` shape of the AVX2 backend. + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u32 { + self.0.simd_gt(other.0).to_bitmask() as u32 + } +} + +impl PartialEq for I8x32 { + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } +} diff --git a/src/simd_nightly/i_word_types.rs b/src/simd_nightly/i_word_types.rs new file mode 100644 index 00000000..d4e1c7ff --- /dev/null +++ b/src/simd_nightly/i_word_types.rs @@ -0,0 +1,430 @@ +//! I16x16 / I16x32 / I32x16 / I64x8 portable-simd wrappers — round-3-portable-simd agent #6. +#![cfg(feature = "nightly-simd")] + +use core::simd::cmp::{SimdOrd, SimdPartialEq, SimdPartialOrd}; +use core::simd::num::SimdInt; +use core::simd::{i16x16, i16x32, i32x16, i64x8}; + +// ════════════════════════════════════════════════════════════════════ +// I16x16 — 16-lane signed 16-bit integer +// ════════════════════════════════════════════════════════════════════ + +/// 16-lane `i16` SIMD vector backed by `core::simd::i16x16`. +/// +/// API mirrors `simd_avx512::I16x16`. Miri can execute every method. +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct I16x16(pub i16x16); + +impl I16x16 { + pub const LANES: usize = 16; + + #[inline(always)] + pub fn splat(v: i16) -> Self { + Self(i16x16::splat(v)) + } + + #[inline(always)] + pub fn from_array(arr: [i16; 16]) -> Self { + Self(i16x16::from_array(arr)) + } + + #[inline(always)] + pub fn from_slice(s: &[i16]) -> Self { + assert!(s.len() >= 16, "I16x16::from_slice needs ≥16 elements"); + Self(i16x16::from_slice(s)) + } + + #[inline(always)] + pub fn to_array(self) -> [i16; 16] { + self.0.to_array() + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [i16]) { + assert!(s.len() >= 16, "I16x16::copy_to_slice needs ≥16 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + /// Wrapping horizontal sum of all lanes. 
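+    /// For example, `I16x16::splat(0x1000).reduce_sum() == 0`:
+    /// 16 × 0x1000 = 0x1_0000, which wraps to 0 in `i16`.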
+ #[inline(always)] + pub fn reduce_sum(self) -> i16 { + self.0.reduce_sum() + } + + #[inline(always)] + pub fn reduce_min(self) -> i16 { + self.0.reduce_min() + } + + #[inline(always)] + pub fn reduce_max(self) -> i16 { + self.0.reduce_max() + } + + // ── Lane-wise min / max ─────────────────────────────────────── + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + // ── Saturating arithmetic ───────────────────────────────────── + + #[inline(always)] + pub fn saturating_add(self, other: Self) -> Self { + Self(self.0.saturating_add(other.0)) + } + + #[inline(always)] + pub fn saturating_sub(self, other: Self) -> Self { + Self(self.0.saturating_sub(other.0)) + } + + // ── Comparisons → bitmask ───────────────────────────────────── + + /// Per-lane equality. Bit i set iff `self[i] == other[i]`. + /// Returns a `u16` matching the 16-lane count. + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u16 { + self.0.simd_eq(other.0).to_bitmask() as u16 + } + + /// Per-lane signed greater-than. Bit i set iff `self[i] > other[i]`. + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u16 { + self.0.simd_gt(other.0).to_bitmask() as u16 + } +} + +impl PartialEq for I16x16 { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } +} + +impl core::fmt::Display for I16x16 { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "I16x16({:?})", &self.to_array()[..]) + } +} + +// ════════════════════════════════════════════════════════════════════ +// I16x32 — 32-lane signed 16-bit integer +// ════════════════════════════════════════════════════════════════════ + +/// 32-lane `i16` SIMD vector backed by `core::simd::i16x32`. +/// +/// API mirrors `simd_avx512::I16x32`. Miri can execute every method. +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct I16x32(pub i16x32); + +impl I16x32 { + pub const LANES: usize = 32; + + #[inline(always)] + pub fn splat(v: i16) -> Self { + Self(i16x32::splat(v)) + } + + #[inline(always)] + pub fn from_array(arr: [i16; 32]) -> Self { + Self(i16x32::from_array(arr)) + } + + #[inline(always)] + pub fn from_slice(s: &[i16]) -> Self { + assert!(s.len() >= 32, "I16x32::from_slice needs ≥32 elements"); + Self(i16x32::from_slice(s)) + } + + #[inline(always)] + pub fn to_array(self) -> [i16; 32] { + self.0.to_array() + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [i16]) { + assert!(s.len() >= 32, "I16x32::copy_to_slice needs ≥32 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + /// Wrapping horizontal sum of all lanes. 
+ #[inline(always)] + pub fn reduce_sum(self) -> i16 { + self.0.reduce_sum() + } + + #[inline(always)] + pub fn reduce_min(self) -> i16 { + self.0.reduce_min() + } + + #[inline(always)] + pub fn reduce_max(self) -> i16 { + self.0.reduce_max() + } + + // ── Lane-wise min / max ─────────────────────────────────────── + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + // ── Saturating arithmetic ───────────────────────────────────── + + #[inline(always)] + pub fn saturating_add(self, other: Self) -> Self { + Self(self.0.saturating_add(other.0)) + } + + #[inline(always)] + pub fn saturating_sub(self, other: Self) -> Self { + Self(self.0.saturating_sub(other.0)) + } + + // ── Comparisons → bitmask ───────────────────────────────────── + + /// Per-lane equality. Bit i set iff `self[i] == other[i]`. + /// Returns a `u32` matching the 32-lane count. + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u32 { + self.0.simd_eq(other.0).to_bitmask() as u32 + } + + /// Per-lane signed greater-than. Bit i set iff `self[i] > other[i]`. + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u32 { + self.0.simd_gt(other.0).to_bitmask() as u32 + } +} + +impl PartialEq for I16x32 { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } +} + +impl core::fmt::Display for I16x32 { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "I16x32({:?})", &self.to_array()[..]) + } +} + +// ════════════════════════════════════════════════════════════════════ +// I32x16 — 16-lane signed 32-bit integer +// ════════════════════════════════════════════════════════════════════ + +/// 16-lane `i32` SIMD vector backed by `core::simd::i32x16`. +/// +/// API mirrors `simd_avx512::I32x16`. Miri can execute every method. +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct I32x16(pub i32x16); + +impl I32x16 { + pub const LANES: usize = 16; + + #[inline(always)] + pub fn splat(v: i32) -> Self { + Self(i32x16::splat(v)) + } + + #[inline(always)] + pub fn from_array(arr: [i32; 16]) -> Self { + Self(i32x16::from_array(arr)) + } + + #[inline(always)] + pub fn from_slice(s: &[i32]) -> Self { + assert!(s.len() >= 16, "I32x16::from_slice needs ≥16 elements"); + Self(i32x16::from_slice(s)) + } + + #[inline(always)] + pub fn to_array(self) -> [i32; 16] { + self.0.to_array() + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [i32]) { + assert!(s.len() >= 16, "I32x16::copy_to_slice needs ≥16 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + /// Wrapping horizontal sum of all lanes. + #[inline(always)] + pub fn reduce_sum(self) -> i32 { + self.0.reduce_sum() + } + + #[inline(always)] + pub fn reduce_min(self) -> i32 { + self.0.reduce_min() + } + + #[inline(always)] + pub fn reduce_max(self) -> i32 { + self.0.reduce_max() + } + + // ── Lane-wise min / max ─────────────────────────────────────── + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + // ── Comparisons → bitmask ───────────────────────────────────── + + /// Per-lane equality. Bit i set iff `self[i] == other[i]`. + /// Returns a `u16` matching the 16-lane count. 
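+    /// For example, `v.cmpeq_mask(v) == 0xFFFF` for any integer vector `v`
+    /// (integer lanes always compare equal to themselves).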
+ #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u16 { + self.0.simd_eq(other.0).to_bitmask() as u16 + } + + /// Per-lane signed greater-than. Bit i set iff `self[i] > other[i]`. + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u16 { + self.0.simd_gt(other.0).to_bitmask() as u16 + } +} + +impl PartialEq for I32x16 { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } +} + +impl core::fmt::Display for I32x16 { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "I32x16({:?})", &self.to_array()[..]) + } +} + +// ════════════════════════════════════════════════════════════════════ +// I64x8 — 8-lane signed 64-bit integer +// ════════════════════════════════════════════════════════════════════ + +/// 8-lane `i64` SIMD vector backed by `core::simd::i64x8`. +/// +/// API mirrors `simd_avx512::I64x8`. Miri can execute every method. +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct I64x8(pub i64x8); + +impl I64x8 { + pub const LANES: usize = 8; + + #[inline(always)] + pub fn splat(v: i64) -> Self { + Self(i64x8::splat(v)) + } + + #[inline(always)] + pub fn from_array(arr: [i64; 8]) -> Self { + Self(i64x8::from_array(arr)) + } + + #[inline(always)] + pub fn from_slice(s: &[i64]) -> Self { + assert!(s.len() >= 8, "I64x8::from_slice needs ≥8 elements"); + Self(i64x8::from_slice(s)) + } + + #[inline(always)] + pub fn to_array(self) -> [i64; 8] { + self.0.to_array() + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [i64]) { + assert!(s.len() >= 8, "I64x8::copy_to_slice needs ≥8 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + /// Wrapping horizontal sum of all lanes. + #[inline(always)] + pub fn reduce_sum(self) -> i64 { + self.0.reduce_sum() + } + + #[inline(always)] + pub fn reduce_min(self) -> i64 { + self.0.reduce_min() + } + + #[inline(always)] + pub fn reduce_max(self) -> i64 { + self.0.reduce_max() + } + + // ── Lane-wise min / max ─────────────────────────────────────── + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + // ── Comparisons → bitmask ───────────────────────────────────── + + /// Per-lane equality. Bit i set iff `self[i] == other[i]`. + /// Returns a `u8` matching the 8-lane count. + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u8 { + self.0.simd_eq(other.0).to_bitmask() as u8 + } + + /// Per-lane signed greater-than. Bit i set iff `self[i] > other[i]`. + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u8 { + self.0.simd_gt(other.0).to_bitmask() as u8 + } +} + +impl PartialEq for I64x8 { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } +} + +impl core::fmt::Display for I64x8 { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "I64x8({:?})", &self.to_array()[..]) + } +} diff --git a/src/simd_nightly/masks.rs b/src/simd_nightly/masks.rs new file mode 100644 index 00000000..03b76be9 --- /dev/null +++ b/src/simd_nightly/masks.rs @@ -0,0 +1,181 @@ +//! Comparison-result mask wrappers — round-3-portable-simd agent #7. 
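+//!
+//! A minimal usage sketch (assumes the flat re-exports in `mod.rs`):
+//!
+//! ```rust,ignore
+//! use ndarray::simd_nightly::F32x16;
+//! let a = F32x16::splat(1.0);
+//! let b = F32x16::splat(2.0);
+//! let m = a.simd_lt(b);            // F32Mask16, all 16 lanes true
+//! assert_eq!(m.to_bitmask(), 0xFFFF);
+//! let r = m.select(a, b);          // a[i] where set, else b[i]
+//! assert_eq!(r.to_array(), [1.0; 16]);
+//! ```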
+#![cfg(feature = "nightly-simd")]
+
+use super::{F32x16, F32x8, F64x4, F64x8};
+use core::simd::prelude::Select;
+use core::simd::Mask;
+
+// ============================================================================
+// F32Mask16 — 16-lane mask for F32x16 comparisons
+// ============================================================================
+
+/// 16-lane mask wrapping `core::simd::Mask<i32, 16>`.
+///
+/// Mirrors `simd_avx512::F32Mask16` so consumer code compiles unchanged
+/// under `--features nightly-simd`.
+#[derive(Copy, Clone, Debug)]
+#[repr(transparent)]
+pub struct F32Mask16(pub Mask<i32, 16>);
+
+impl F32Mask16 {
+    /// Convert to a 16-bit packed bitmask (matches the AVX-512 `__mmask16`
+    /// shape). Bit i is set iff lane i of the mask is true.
+    #[inline(always)]
+    pub fn to_bitmask(self) -> u16 {
+        self.0.to_bitmask() as u16
+    }
+
+    /// Reconstruct a mask from a 16-bit packed bitmask.
+    #[inline(always)]
+    pub fn from_bitmask(bits: u16) -> Self {
+        Self(Mask::<i32, 16>::from_bitmask(bits as u64))
+    }
+
+    /// Per-lane select: returns `true_val[i]` where mask[i] is set,
+    /// else `false_val[i]`.
+    #[inline(always)]
+    pub fn select(self, true_val: F32x16, false_val: F32x16) -> F32x16 {
+        F32x16(self.0.select(true_val.0, false_val.0))
+    }
+
+    /// Returns `true` if every lane is set.
+    #[inline(always)]
+    pub fn all(self) -> bool {
+        self.0.all()
+    }
+
+    /// Returns `true` if any lane is set.
+    #[inline(always)]
+    pub fn any(self) -> bool {
+        self.0.any()
+    }
+}
+
+// ============================================================================
+// F32Mask8 — 8-lane mask for F32x8 comparisons
+// ============================================================================
+
+/// 8-lane mask wrapping `core::simd::Mask<i32, 8>`.
+#[derive(Copy, Clone, Debug)]
+#[repr(transparent)]
+pub struct F32Mask8(pub Mask<i32, 8>);
+
+impl F32Mask8 {
+    /// Convert to an 8-bit packed bitmask. Bit i is set iff lane i is true.
+    #[inline(always)]
+    pub fn to_bitmask(self) -> u8 {
+        self.0.to_bitmask() as u8
+    }
+
+    /// Reconstruct a mask from an 8-bit packed bitmask.
+    #[inline(always)]
+    pub fn from_bitmask(bits: u8) -> Self {
+        Self(Mask::<i32, 8>::from_bitmask(bits as u64))
+    }
+
+    /// Per-lane select: returns `true_val[i]` where mask[i] is set,
+    /// else `false_val[i]`.
+    #[inline(always)]
+    pub fn select(self, true_val: F32x8, false_val: F32x8) -> F32x8 {
+        F32x8(self.0.select(true_val.0, false_val.0))
+    }
+
+    /// Returns `true` if every lane is set.
+    #[inline(always)]
+    pub fn all(self) -> bool {
+        self.0.all()
+    }
+
+    /// Returns `true` if any lane is set.
+    #[inline(always)]
+    pub fn any(self) -> bool {
+        self.0.any()
+    }
+}
+
+// ============================================================================
+// F64Mask8 — 8-lane mask for F64x8 comparisons
+// ============================================================================
+
+/// 8-lane mask wrapping `core::simd::Mask<i64, 8>`.
+///
+/// Mirrors `simd_avx512::F64Mask8` so consumer code compiles unchanged
+/// under `--features nightly-simd`.
+#[derive(Copy, Clone, Debug)]
+#[repr(transparent)]
+pub struct F64Mask8(pub Mask<i64, 8>);
+
+impl F64Mask8 {
+    /// Convert to an 8-bit packed bitmask. Bit i is set iff lane i is true.
+    #[inline(always)]
+    pub fn to_bitmask(self) -> u8 {
+        self.0.to_bitmask() as u8
+    }
+
+    /// Reconstruct a mask from an 8-bit packed bitmask.
+    #[inline(always)]
+    pub fn from_bitmask(bits: u8) -> Self {
+        Self(Mask::<i64, 8>::from_bitmask(bits as u64))
+    }
+
+    /// Per-lane select: returns `true_val[i]` where mask[i] is set,
+    /// else `false_val[i]`.
+    #[inline(always)]
+    pub fn select(self, true_val: F64x8, false_val: F64x8) -> F64x8 {
+        F64x8(self.0.select(true_val.0, false_val.0))
+    }
+
+    /// Returns `true` if every lane is set.
+    #[inline(always)]
+    pub fn all(self) -> bool {
+        self.0.all()
+    }
+
+    /// Returns `true` if any lane is set.
+    #[inline(always)]
+    pub fn any(self) -> bool {
+        self.0.any()
+    }
+}
+
+// ============================================================================
+// F64Mask4 — 4-lane mask for F64x4 comparisons
+// ============================================================================
+
+/// 4-lane mask wrapping `core::simd::Mask<i64, 4>`.
+#[derive(Copy, Clone, Debug)]
+#[repr(transparent)]
+pub struct F64Mask4(pub Mask<i64, 4>);
+
+impl F64Mask4 {
+    /// Convert to a 4-bit packed bitmask. Bit i is set iff lane i is true.
+    #[inline(always)]
+    pub fn to_bitmask(self) -> u8 {
+        self.0.to_bitmask() as u8
+    }
+
+    /// Reconstruct a mask from a 4-bit packed bitmask.
+    #[inline(always)]
+    pub fn from_bitmask(bits: u8) -> Self {
+        Self(Mask::<i64, 4>::from_bitmask(bits as u64))
+    }
+
+    /// Per-lane select: returns `true_val[i]` where mask[i] is set,
+    /// else `false_val[i]`.
+    #[inline(always)]
+    pub fn select(self, true_val: F64x4, false_val: F64x4) -> F64x4 {
+        F64x4(self.0.select(true_val.0, false_val.0))
+    }
+
+    /// Returns `true` if every lane is set.
+    #[inline(always)]
+    pub fn all(self) -> bool {
+        self.0.all()
+    }
+
+    /// Returns `true` if any lane is set.
+    #[inline(always)]
+    pub fn any(self) -> bool {
+        self.0.any()
+    }
+}
diff --git a/src/simd_nightly/mod.rs b/src/simd_nightly/mod.rs
new file mode 100644
index 00000000..f0f902ac
--- /dev/null
+++ b/src/simd_nightly/mod.rs
@@ -0,0 +1,45 @@
+//! Portable-SIMD polyfill backend — full 30-type coverage (NIGHTLY ONLY).
+//!
+//! Wraps `core::simd::*` so miri can execute the polyfill paths.
+//! Intrinsics backends (`simd_avx512.rs` / `simd_avx2.rs`) are opaque
+//! to miri; `core::simd` is not. With `--features nightly-simd`, consumer
+//! code using `ndarray::simd_nightly::*` becomes miri-checkable.
+//!
+//! Gated entirely behind `#[cfg(feature = "nightly-simd")]`. Requires
+//! `cargo +nightly` because the file pulls in `#![feature(portable_simd)]`
+//! at the crate root (see `lib.rs`).
+//!
+//! # Coverage
+//!
+//! 30 types across 12 sub-modules, populated by the round-3-portable-simd
+//! fleet. The module aggregator (this file) re-exports the public surface
+//! flat so consumers write `use ndarray::simd_nightly::F32x16` rather
+//! than `use ndarray::simd_nightly::f32_types::F32x16`.
+
+#![cfg(feature = "nightly-simd")]
+
+pub mod bf16_types;
+pub mod exotic_methods;
+pub mod f16_types;
+pub mod f32_types;
+pub mod f64_types;
+pub mod i8_types;
+pub mod i_word_types;
+pub mod masks;
+pub mod ops;
+pub mod u8_types;
+pub mod u_word_types;
+
+#[cfg(test)]
+mod tests;
+
+// Flat re-exports — consumer surface matches `crate::simd::*` shape.
+pub use bf16_types::{BF16x16, BF16x8};
+pub use f16_types::F16x16;
+pub use f32_types::{F32x16, F32x8};
+pub use f64_types::{F64x4, F64x8};
+pub use i8_types::{I8x32, I8x64};
+pub use i_word_types::{I16x16, I16x32, I32x16, I64x8};
+pub use masks::{F32Mask16, F32Mask8, F64Mask4, F64Mask8};
+pub use u8_types::{U8x32, U8x64};
+pub use u_word_types::{U16x32, U32x16, U32x8, U64x4, U64x8};
diff --git a/src/simd_nightly/ops.rs b/src/simd_nightly/ops.rs
new file mode 100644
index 00000000..6a5cdb5d
--- /dev/null
+++ b/src/simd_nightly/ops.rs
@@ -0,0 +1,268 @@
+//! Operator impls for all simd_nightly types — round-3-portable-simd agent #10.
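+//!
+//! Sketch of the operator surface these macros produce (values are
+//! illustrative only):
+//!
+//! ```rust,ignore
+//! let a = F32x16::splat(2.0);
+//! let b = F32x16::splat(3.0);
+//! assert_eq!((a * b + a).to_array()[0], 8.0); // Mul + Add via impl_fp_ops!
+//! let mut c = a;
+//! c += b;                                     // AddAssign via impl_fp_ops!
+//! assert_eq!(c.to_array()[0], 5.0);
+//! ```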
+//! +//! Two main macros: +//! - `impl_fp_ops!` — Add/Sub/Mul/Div/Neg + assign variants for float types. +//! - `impl_int_ops!` — Add/Sub/BitAnd/BitOr/BitXor + assign variants for int types. +//! +//! A separate `impl_default!` macro handles `Default` for types whose own +//! module has NOT already derived/implemented it (e.g. F32x16/F32x8 already +//! implement `Default` in f32_types.rs; those are excluded here). +#![cfg(feature = "nightly-simd")] + +// ════════════════════════════════════════════════════════════════════ +// impl_fp_ops — float arithmetic + Neg + assign variants +// ════════════════════════════════════════════════════════════════════ + +/// Implement `Add`, `Sub`, `Mul`, `Div`, `Neg`, and their `*Assign` +/// counterparts for a newtype wrapper whose inner field is `.0`. +/// +/// The wrapped type must itself implement the corresponding `core::ops` +/// traits (which all `core::simd` vector types do). +macro_rules! impl_fp_ops { + ($name:ty) => { + impl core::ops::Add for $name { + type Output = Self; + #[inline(always)] + fn add(self, rhs: Self) -> Self { + Self(self.0 + rhs.0) + } + } + + impl core::ops::Sub for $name { + type Output = Self; + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + Self(self.0 - rhs.0) + } + } + + impl core::ops::Mul for $name { + type Output = Self; + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + Self(self.0 * rhs.0) + } + } + + impl core::ops::Div for $name { + type Output = Self; + #[inline(always)] + fn div(self, rhs: Self) -> Self { + Self(self.0 / rhs.0) + } + } + + impl core::ops::Neg for $name { + type Output = Self; + #[inline(always)] + fn neg(self) -> Self { + Self(-self.0) + } + } + + impl core::ops::AddAssign for $name { + #[inline(always)] + fn add_assign(&mut self, rhs: Self) { + self.0 = self.0 + rhs.0; + } + } + + impl core::ops::SubAssign for $name { + #[inline(always)] + fn sub_assign(&mut self, rhs: Self) { + self.0 = self.0 - rhs.0; + } + } + + impl core::ops::MulAssign for $name { + #[inline(always)] + fn mul_assign(&mut self, rhs: Self) { + self.0 = self.0 * rhs.0; + } + } + + impl core::ops::DivAssign for $name { + #[inline(always)] + fn div_assign(&mut self, rhs: Self) { + self.0 = self.0 / rhs.0; + } + } + }; +} + +// ════════════════════════════════════════════════════════════════════ +// impl_int_ops — integer Add/Sub/BitAnd/BitOr/BitXor + assign variants +// (wrapping semantics come from core::simd automatically) +// ════════════════════════════════════════════════════════════════════ + +/// Implement `Add`, `Sub`, `BitAnd`, `BitOr`, `BitXor`, and their +/// `*Assign` counterparts for a newtype wrapper whose inner field is `.0`. +/// +/// `Neg` is NOT included here; for signed integer types use +/// `impl_int_neg!` separately. +macro_rules! 
impl_int_ops { + ($name:ty) => { + impl core::ops::Add for $name { + type Output = Self; + #[inline(always)] + fn add(self, rhs: Self) -> Self { + Self(self.0 + rhs.0) + } + } + + impl core::ops::Sub for $name { + type Output = Self; + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + Self(self.0 - rhs.0) + } + } + + impl core::ops::BitAnd for $name { + type Output = Self; + #[inline(always)] + fn bitand(self, rhs: Self) -> Self { + Self(self.0 & rhs.0) + } + } + + impl core::ops::BitOr for $name { + type Output = Self; + #[inline(always)] + fn bitor(self, rhs: Self) -> Self { + Self(self.0 | rhs.0) + } + } + + impl core::ops::BitXor for $name { + type Output = Self; + #[inline(always)] + fn bitxor(self, rhs: Self) -> Self { + Self(self.0 ^ rhs.0) + } + } + + impl core::ops::AddAssign for $name { + #[inline(always)] + fn add_assign(&mut self, rhs: Self) { + self.0 = self.0 + rhs.0; + } + } + + impl core::ops::SubAssign for $name { + #[inline(always)] + fn sub_assign(&mut self, rhs: Self) { + self.0 = self.0 - rhs.0; + } + } + + impl core::ops::BitAndAssign for $name { + #[inline(always)] + fn bitand_assign(&mut self, rhs: Self) { + self.0 = self.0 & rhs.0; + } + } + + impl core::ops::BitOrAssign for $name { + #[inline(always)] + fn bitor_assign(&mut self, rhs: Self) { + self.0 = self.0 | rhs.0; + } + } + + impl core::ops::BitXorAssign for $name { + #[inline(always)] + fn bitxor_assign(&mut self, rhs: Self) { + self.0 = self.0 ^ rhs.0; + } + } + }; +} + +// ════════════════════════════════════════════════════════════════════ +// impl_int_neg — Neg for signed integer types (separate from int_ops +// so unsigned types don't accidentally get it) +// ════════════════════════════════════════════════════════════════════ + +macro_rules! impl_int_neg { + ($name:ty) => { + impl core::ops::Neg for $name { + type Output = Self; + #[inline(always)] + fn neg(self) -> Self { + Self(-self.0) + } + } + }; +} + +// ════════════════════════════════════════════════════════════════════ +// impl_default — Default (zero / false splat) for types that do NOT +// already provide a Default impl in their own module. +// +// F32x16 and F32x8 already have Default in f32_types.rs — excluded. +// ════════════════════════════════════════════════════════════════════ + +macro_rules! impl_default { + ($name:ty) => { + impl Default for $name { + #[inline(always)] + fn default() -> Self { + Self(Default::default()) + } + } + }; +} + +// ════════════════════════════════════════════════════════════════════ +// Float type invocations +// ════════════════════════════════════════════════════════════════════ + +// F32x16 and F32x8 — ops only; Default already in f32_types.rs. +impl_fp_ops!(super::f32_types::F32x16); +impl_fp_ops!(super::f32_types::F32x8); + +// F64x8 and F64x4 — ops + Default (not in f64_types.rs). +impl_fp_ops!(super::f64_types::F64x8); +impl_default!(super::f64_types::F64x8); +impl_fp_ops!(super::f64_types::F64x4); +impl_default!(super::f64_types::F64x4); + +// ════════════════════════════════════════════════════════════════════ +// Unsigned integer type invocations +// ════════════════════════════════════════════════════════════════════ + +// U8x32, U8x64 — ops only; Default already in u8_types.rs. +impl_int_ops!(super::u8_types::U8x32); +impl_int_ops!(super::u8_types::U8x64); + +// All u_word types have Default in u_word_types.rs — ops only. 
+impl_int_ops!(super::u_word_types::U16x32); +impl_int_ops!(super::u_word_types::U32x16); +impl_int_ops!(super::u_word_types::U32x8); +impl_int_ops!(super::u_word_types::U64x8); +impl_int_ops!(super::u_word_types::U64x4); + +// ════════════════════════════════════════════════════════════════════ +// Signed integer type invocations (int_ops + Neg + Default) +// ════════════════════════════════════════════════════════════════════ + +impl_int_ops!(super::i8_types::I8x32); +impl_int_neg!(super::i8_types::I8x32); +impl_default!(super::i8_types::I8x32); +impl_int_ops!(super::i8_types::I8x64); +impl_int_neg!(super::i8_types::I8x64); +impl_default!(super::i8_types::I8x64); + +impl_int_ops!(super::i_word_types::I16x16); +impl_int_neg!(super::i_word_types::I16x16); +impl_default!(super::i_word_types::I16x16); +impl_int_ops!(super::i_word_types::I16x32); +impl_int_neg!(super::i_word_types::I16x32); +impl_default!(super::i_word_types::I16x32); +impl_int_ops!(super::i_word_types::I32x16); +impl_int_neg!(super::i_word_types::I32x16); +impl_default!(super::i_word_types::I32x16); +impl_int_ops!(super::i_word_types::I64x8); +impl_int_neg!(super::i_word_types::I64x8); +impl_default!(super::i_word_types::I64x8); diff --git a/src/simd_nightly/tests.rs b/src/simd_nightly/tests.rs new file mode 100644 index 00000000..cddec229 --- /dev/null +++ b/src/simd_nightly/tests.rs @@ -0,0 +1,815 @@ +//! Round-3-portable-simd parity tests — agent #12. +//! +//! Every test verifies `simd_nightly::*` behaves identically to pure scalar +//! reference math. No comparison against `simd_avx512::*` (miri-opaque). +#![cfg(all(test, feature = "nightly-simd"))] + +use super::*; + +// ════════════════════════════════════════════════════════════════════════════ +// 1. Constructor roundtrip — splat/from_array/from_slice/copy_to_slice +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn f32x16_splat_roundtrip() { + let v = F32x16::splat(3.14_f32); + let arr = v.to_array(); + assert!(arr.iter().all(|&x| (x - 3.14_f32).abs() < 1e-6)); +} + +#[test] +fn f32x16_from_array_identity() { + let src: [f32; 16] = core::array::from_fn(|i| i as f32 * 7.0); + assert_eq!(F32x16::from_array(src).to_array(), src); +} + +#[test] +fn f32x16_from_slice_copy_to_slice() { + let src: [f32; 16] = core::array::from_fn(|i| i as f32); + let mut dst = [0.0_f32; 16]; + F32x16::from_slice(&src).copy_to_slice(&mut dst); + assert_eq!(src, dst); +} + +#[test] +fn f32x8_splat_roundtrip() { + let v = F32x8::splat(-1.0_f32); + assert!(v.to_array().iter().all(|&x| x == -1.0_f32)); +} + +#[test] +fn f32x8_from_array_identity() { + let src: [f32; 8] = core::array::from_fn(|i| i as f32 * 3.0); + assert_eq!(F32x8::from_array(src).to_array(), src); +} + +#[test] +fn f32x8_from_slice_copy_to_slice() { + let src: [f32; 8] = core::array::from_fn(|i| (i * 5) as f32); + let mut dst = [0.0_f32; 8]; + F32x8::from_slice(&src).copy_to_slice(&mut dst); + assert_eq!(src, dst); +} + +#[test] +fn f64x8_splat_roundtrip() { + let v = F64x8::splat(2.718_f64); + assert!(v.to_array().iter().all(|&x| (x - 2.718_f64).abs() < 1e-12)); +} + +#[test] +fn f64x8_from_array_identity() { + let src: [f64; 8] = core::array::from_fn(|i| i as f64 * 7.0); + assert_eq!(F64x8::from_array(src).to_array(), src); +} + +#[test] +fn f64x8_from_slice_copy_to_slice() { + let src: [f64; 8] = core::array::from_fn(|i| i as f64); + let mut dst = [0.0_f64; 8]; + F64x8::from_slice(&src).copy_to_slice(&mut dst); + assert_eq!(src, dst); +} + +#[test] +fn f64x4_splat_roundtrip() { + 
let v = F64x4::splat(1.414_f64); + assert!(v.to_array().iter().all(|&x| (x - 1.414_f64).abs() < 1e-12)); +} + +#[test] +fn f64x4_from_array_identity() { + let src: [f64; 4] = core::array::from_fn(|i| (i + 1) as f64 * 2.5); + assert_eq!(F64x4::from_array(src).to_array(), src); +} + +#[test] +fn f64x4_from_slice_copy_to_slice() { + let src: [f64; 4] = [1.0, 2.0, 3.0, 4.0]; + let mut dst = [0.0_f64; 4]; + F64x4::from_slice(&src).copy_to_slice(&mut dst); + assert_eq!(src, dst); +} + +// ════════════════════════════════════════════════════════════════════════════ +// 2. Reduction parity — sum / min / max vs scalar fold +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn f32x16_reduce_sum_parity() { + let src: [f32; 16] = core::array::from_fn(|i| (i as f32) * 7.0); + let scalar_sum: f32 = src.iter().copied().sum(); + let simd_sum = F32x16::from_array(src).reduce_sum(); + // Allow small FP rounding difference + assert!((simd_sum - scalar_sum).abs() < 1e-3, "sum mismatch: {simd_sum} vs {scalar_sum}"); +} + +#[test] +fn f32x16_reduce_min_parity() { + let src: [f32; 16] = core::array::from_fn(|i| (i as f32) * 7.0); + let scalar_min = src.iter().copied().fold(f32::INFINITY, f32::min); + assert_eq!(F32x16::from_array(src).reduce_min(), scalar_min); +} + +#[test] +fn f32x16_reduce_max_parity() { + let src: [f32; 16] = core::array::from_fn(|i| (i as f32) * 7.0); + let scalar_max = src.iter().copied().fold(f32::NEG_INFINITY, f32::max); + assert_eq!(F32x16::from_array(src).reduce_max(), scalar_max); +} + +#[test] +fn f32x8_reduce_sum_parity() { + let src: [f32; 8] = core::array::from_fn(|i| (i as f32) * 7.0); + let scalar_sum: f32 = src.iter().copied().sum(); + let simd_sum = F32x8::from_array(src).reduce_sum(); + assert!((simd_sum - scalar_sum).abs() < 1e-3); +} + +#[test] +fn f32x8_reduce_min_max_parity() { + let src: [f32; 8] = core::array::from_fn(|i| (i as f32) * 3.0); + let scalar_min = src.iter().copied().fold(f32::INFINITY, f32::min); + let scalar_max = src.iter().copied().fold(f32::NEG_INFINITY, f32::max); + let v = F32x8::from_array(src); + assert_eq!(v.reduce_min(), scalar_min); + assert_eq!(v.reduce_max(), scalar_max); +} + +#[test] +fn f64x8_reduce_sum_parity() { + let src: [f64; 8] = core::array::from_fn(|i| (i as f64) * 7.0); + let scalar_sum: f64 = src.iter().copied().sum(); + let simd_sum = F64x8::from_array(src).reduce_sum(); + assert!((simd_sum - scalar_sum).abs() < 1e-9); +} + +#[test] +fn f64x8_reduce_min_max_parity() { + let src: [f64; 8] = core::array::from_fn(|i| (i as f64) * 5.0); + let scalar_min = src.iter().copied().fold(f64::INFINITY, f64::min); + let scalar_max = src.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let v = F64x8::from_array(src); + assert_eq!(v.reduce_min(), scalar_min); + assert_eq!(v.reduce_max(), scalar_max); +} + +#[test] +fn f64x4_reduce_sum_parity() { + let src: [f64; 4] = core::array::from_fn(|i| (i as f64) * 7.0); + let scalar_sum: f64 = src.iter().copied().sum(); + let simd_sum = F64x4::from_array(src).reduce_sum(); + assert!((simd_sum - scalar_sum).abs() < 1e-9); +} + +// ════════════════════════════════════════════════════════════════════════════ +// 3. 
Comparison mask parity — bitmask matches scalar per-lane predicate +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn f32x16_cmpeq_mask_parity() { + let a: [f32; 16] = core::array::from_fn(|i| (i % 4) as f32); + let b: [f32; 16] = core::array::from_fn(|i| (i % 2) as f32); + let mask = F32x16::from_array(a) + .simd_eq(F32x16::from_array(b)) + .to_bitmask(); + let mut expected: u16 = 0; + for i in 0..16 { + if a[i] == b[i] { + expected |= 1 << i; + } + } + assert_eq!(mask, expected, "cmpeq mask mismatch"); +} + +#[test] +fn f32x16_cmpgt_mask_parity() { + let a: [f32; 16] = core::array::from_fn(|i| i as f32); + let b = F32x16::splat(7.5); + let mask = F32x16::from_array(a).simd_gt(b).to_bitmask(); + let mut expected: u16 = 0; + for i in 0..16 { + if a[i] > 7.5 { + expected |= 1 << i; + } + } + assert_eq!(mask, expected, "cmpgt mask mismatch"); +} + +#[test] +fn f32x8_cmpeq_mask_parity() { + let a: [f32; 8] = core::array::from_fn(|i| (i % 2) as f32); + let b: [f32; 8] = [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]; + let mask = F32x8::from_array(a) + .simd_eq(F32x8::from_array(b)) + .to_bitmask(); + let mut expected: u8 = 0; + for i in 0..8 { + if a[i] == b[i] { + expected |= 1 << i; + } + } + assert_eq!(mask, expected); +} + +#[test] +fn f64x8_cmpeq_mask_parity() { + let a: [f64; 8] = core::array::from_fn(|i| (i % 3) as f64); + let b: [f64; 8] = core::array::from_fn(|i| (i % 2) as f64); + let mask = F64x8::from_array(a) + .simd_eq(F64x8::from_array(b)) + .to_bitmask(); + let mut expected: u8 = 0; + for i in 0..8 { + if a[i] == b[i] { + expected |= 1 << i; + } + } + assert_eq!(mask, expected); +} + +#[test] +fn f64x4_cmpgt_mask_parity() { + let a: [f64; 4] = [5.0, 1.0, 7.0, 3.0]; + let b = F64x4::splat(4.0); + let mask = F64x4::from_array(a).simd_gt(b).to_bitmask(); + // Lanes 0 (5>4) and 2 (7>4) → bits 0 and 2 → 0b0101 = 5 + assert_eq!(mask, 0b0101_u8); +} + +#[test] +fn u8x64_cmpeq_mask_parity() { + let a: [u8; 64] = core::array::from_fn(|i| (i & 7) as u8); + let b: [u8; 64] = core::array::from_fn(|i| (i & 6) as u8); + let mask = u8_types::U8x64::from_array(a).cmpeq_mask(u8_types::U8x64::from_array(b)); + for i in 0..64 { + let expected_bit = if a[i] == b[i] { 1u64 } else { 0u64 }; + assert_eq!((mask >> i) & 1, expected_bit, "lane {i}"); + } +} + +#[test] +fn u8x64_cmpgt_mask_parity() { + let a: [u8; 64] = core::array::from_fn(|i| i as u8); + let b: [u8; 64] = core::array::from_fn(|i| (63 - i) as u8); + let mask = u8_types::U8x64::from_array(a).cmpgt_mask(u8_types::U8x64::from_array(b)); + for i in 0..64 { + let expected_bit = if a[i] > b[i] { 1u64 } else { 0u64 }; + assert_eq!((mask >> i) & 1, expected_bit, "lane {i}: {} > {}?", a[i], b[i]); + } +} + +#[test] +fn u8x32_cmpeq_mask_parity() { + let a: [u8; 32] = core::array::from_fn(|i| (i & 3) as u8); + let b: [u8; 32] = core::array::from_fn(|i| (i & 2) as u8); + let mask = u8_types::U8x32::from_array(a).cmpeq_mask(u8_types::U8x32::from_array(b)); + let mut expected: u32 = 0; + for i in 0..32 { + if a[i] == b[i] { + expected |= 1 << i; + } + } + assert_eq!(mask, expected); +} + +// ════════════════════════════════════════════════════════════════════════════ +// 4. 
Saturating arithmetic — clamps at max/min for unsigned integer types +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn u8x64_saturating_add_clamps_at_max() { + let a = u8_types::U8x64::splat(200); + let b = u8_types::U8x64::splat(100); + let r = a.saturating_add(b); + // scalar: min(200 + 100, 255) = 255 + assert!(r.to_array().iter().all(|&x| x == 255), "expected 255, got {:?}", r.to_array()[0]); +} + +#[test] +fn u8x64_saturating_sub_clamps_at_zero() { + let a = u8_types::U8x64::splat(10); + let b = u8_types::U8x64::splat(50); + let r = a.saturating_sub(b); + // scalar: max(10 - 50, 0) = 0 + assert!(r.to_array().iter().all(|&x| x == 0)); +} + +#[test] +fn u8x32_saturating_add_clamps_at_max() { + let a = u8_types::U8x32::splat(255); + let b = u8_types::U8x32::splat(1); + let r = a.saturating_add(b); + assert!(r.to_array().iter().all(|&x| x == 255)); +} + +#[test] +fn u8x32_saturating_sub_clamps_at_zero() { + let a = u8_types::U8x32::splat(0); + let b = u8_types::U8x32::splat(100); + let r = a.saturating_sub(b); + assert!(r.to_array().iter().all(|&x| x == 0)); +} + +#[test] +fn u16x32_saturating_add_clamps_at_max() { + let a = u_word_types::U16x32::splat(u16::MAX - 1); + let b = u_word_types::U16x32::splat(10); + let r = a.saturating_add(b); + assert!(r.to_array().iter().all(|&x| x == u16::MAX)); +} + +#[test] +fn u16x32_saturating_sub_clamps_at_zero() { + let a = u_word_types::U16x32::splat(5); + let b = u_word_types::U16x32::splat(10); + let r = a.saturating_sub(b); + assert!(r.to_array().iter().all(|&x| x == 0)); +} + +// ════════════════════════════════════════════════════════════════════════════ +// 5. FMA bit-exact — mul_add(0.5, 2.0, 1.0) = 2.0 per lane +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn f32x16_fma_exact() { + // 0.5 * 2.0 + 1.0 = 2.0 (exact in IEEE-754) + let result = F32x16::splat(0.5).mul_add(F32x16::splat(2.0), F32x16::splat(1.0)); + let arr = result.to_array(); + assert!(arr.iter().all(|&v| v == 2.0_f32), "expected all 2.0, got {:?}", &arr[..4]); +} + +#[test] +fn f32x8_fma_exact() { + let result = F32x8::splat(0.5).mul_add(F32x8::splat(2.0), F32x8::splat(1.0)); + assert!(result.to_array().iter().all(|&v| v == 2.0_f32)); +} + +#[test] +fn f64x8_fma_exact() { + // 0.5 * 2.0 + 1.0 = 2.0 (exact) + let result = F64x8::splat(0.5).mul_add(F64x8::splat(2.0), F64x8::splat(1.0)); + assert!(result.to_array().iter().all(|&v| v == 2.0_f64)); +} + +#[test] +fn f64x4_fma_exact() { + let result = F64x4::splat(0.5).mul_add(F64x4::splat(2.0), F64x4::splat(1.0)); + assert!(result.to_array().iter().all(|&v| v == 2.0_f64)); +} + +#[test] +fn f32x16_fma_scalar_parity() { + // Verify mul_add matches scalar f32::mul_add lane-by-lane + let a: [f32; 16] = core::array::from_fn(|i| (i as f32) * 0.1); + let b: [f32; 16] = core::array::from_fn(|i| (i as f32) * 0.3); + let c: [f32; 16] = core::array::from_fn(|i| i as f32); + let simd_result = F32x16::from_array(a).mul_add(F32x16::from_array(b), F32x16::from_array(c)); + for (i, (&r, (&ai, (&bi, &ci)))) in simd_result + .to_array() + .iter() + .zip(a.iter().zip(b.iter().zip(c.iter()))) + .enumerate() + { + let scalar = ai.mul_add(bi, ci); + assert!((r - scalar).abs() < 1e-5, "lane {i}: simd={r} scalar={scalar}"); + } +} + +// ════════════════════════════════════════════════════════════════════════════ +// 6. 
BF16 / F16 roundtrip — from_f32 → to_f32 within expected precision +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn bf16x16_from_f32_to_f32_lossy_within_bounds() { + // BF16 truncates the low 16 mantissa bits; error ≤ 2^(e-7) where e = exponent + let inputs: [f32; 16] = core::array::from_fn(|i| (i as f32 + 1.0) * 0.5); + let v = bf16_types::BF16x16::from_f32_truncate(inputs); + let out = v.to_f32_lossy(); + for (i, (&orig, &back)) in inputs.iter().zip(out.iter()).enumerate() { + // BF16 truncation error is at most |orig| * 2^-7 + let max_err = orig.abs() * (2.0_f32.powi(-7)) + f32::EPSILON; + assert!((orig - back).abs() <= max_err, "lane {i}: orig={orig} back={back} err={}", (orig - back).abs()); + } +} + +#[test] +fn bf16x8_from_f32_to_f32_lossy_within_bounds() { + let inputs: [f32; 8] = core::array::from_fn(|i| i as f32 * 2.0 + 1.0); + let v = bf16_types::BF16x8::from_f32_truncate(inputs); + let out = v.to_f32_lossy(); + for (i, (&orig, &back)) in inputs.iter().zip(out.iter()).enumerate() { + let max_err = orig.abs() * (2.0_f32.powi(-7)) + f32::EPSILON; + assert!((orig - back).abs() <= max_err, "lane {i}: orig={orig} back={back}"); + } +} + +#[test] +fn bf16x16_bit_pattern_stability() { + // Round-trip through bit patterns must be identity + let bits: [u16; 16] = core::array::from_fn(|i| (i as u16) * 0x100 + 0x3F00); + let v = bf16_types::BF16x16::from_array(bits); + assert_eq!(v.to_array(), bits); +} + +#[test] +fn bf16x8_bit_pattern_stability() { + let bits: [u16; 8] = core::array::from_fn(|i| (i as u16) * 0x200 + 0x3F00); + let v = bf16_types::BF16x8::from_array(bits); + assert_eq!(v.to_array(), bits); +} + +#[test] +fn f16x16_from_f32_to_f32_roundtrip_within_1ulp() { + // Representable f16 values (0..=15 as f32) should round-trip exactly. + let inputs: [f32; 16] = core::array::from_fn(|i| i as f32); + let v = f16_types::F16x16::from_f32_array(inputs); + let out = v.to_f32_array(); + for (i, (&orig, &back)) in inputs.iter().zip(out.iter()).enumerate() { + // These integers are exactly representable in f16 (range 0..2048) + assert!((orig - back).abs() <= 0.5, "lane {i}: orig={orig} back={back}"); + } +} + +#[test] +fn f16x16_bit_pattern_stability() { + // from_array → to_array must be identity for raw bit patterns + let raw: [u16; 16] = core::array::from_fn(|i| 0x3C00_u16 + i as u16); // ~1.0 range + let v = f16_types::F16x16::from_array(raw); + assert_eq!(v.to_array(), raw); +} + +// ════════════════════════════════════════════════════════════════════════════ +// 7. 
Mask select — F32Mask16 / F32Mask8 / F64Mask8 / F64Mask4 +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn f32mask16_select_all_true() { + let mask = masks::F32Mask16::from_bitmask(0xFFFF); + let t = F32x16::splat(1.0); + let f = F32x16::splat(2.0); + assert!(mask.select(t, f).to_array().iter().all(|&x| x == 1.0)); +} + +#[test] +fn f32mask16_select_all_false() { + let mask = masks::F32Mask16::from_bitmask(0x0000); + let t = F32x16::splat(1.0); + let f = F32x16::splat(2.0); + assert!(mask.select(t, f).to_array().iter().all(|&x| x == 2.0)); +} + +#[test] +fn f32mask16_select_alternating() { + // bits 0,2,4,...,14 set → even lanes from true_val, odd from false_val + let mask = masks::F32Mask16::from_bitmask(0x5555); // 0101010101010101 + let t = F32x16::splat(1.0); + let f = F32x16::splat(2.0); + let result = mask.select(t, f).to_array(); + for i in 0..16 { + if i % 2 == 0 { + assert_eq!(result[i], 1.0, "even lane {i} should be 1.0"); + } else { + assert_eq!(result[i], 2.0, "odd lane {i} should be 2.0"); + } + } +} + +#[test] +fn f32mask16_select_scalar_parity() { + // Build mask from comparison, then verify select matches scalar ternary + let a: [f32; 16] = core::array::from_fn(|i| i as f32); + let threshold = F32x16::splat(7.5); + let va = F32x16::from_array(a); + let mask = va.simd_lt(threshold); + let result = mask + .select(F32x16::splat(100.0), F32x16::splat(200.0)) + .to_array(); + for i in 0..16 { + let expected = if a[i] < 7.5 { 100.0_f32 } else { 200.0_f32 }; + assert_eq!(result[i], expected, "lane {i}"); + } +} + +#[test] +fn f32mask8_select_scalar_parity() { + let a: [f32; 8] = core::array::from_fn(|i| i as f32); + let threshold = F32x8::splat(3.5); + let va = F32x8::from_array(a); + let mask = va.simd_gt(threshold); + let result = mask + .select(F32x8::splat(10.0), F32x8::splat(20.0)) + .to_array(); + for i in 0..8 { + let expected = if a[i] > 3.5 { 10.0_f32 } else { 20.0_f32 }; + assert_eq!(result[i], expected, "lane {i}"); + } +} + +#[test] +fn f64mask8_select_scalar_parity() { + let a: [f64; 8] = core::array::from_fn(|i| i as f64); + let threshold = F64x8::splat(4.5); + let va = F64x8::from_array(a); + let mask = va.simd_ge(threshold); + let result = mask.select(F64x8::splat(1.0), F64x8::splat(0.0)).to_array(); + for i in 0..8 { + let expected = if a[i] >= 4.5 { 1.0_f64 } else { 0.0_f64 }; + assert_eq!(result[i], expected, "lane {i}"); + } +} + +#[test] +fn f64mask4_select_scalar_parity() { + let a: [f64; 4] = [1.0, 5.0, 3.0, 7.0]; + let threshold = F64x4::splat(4.0); + let va = F64x4::from_array(a); + let mask = va.simd_gt(threshold); + let result = mask + .select(F64x4::splat(99.0), F64x4::splat(0.0)) + .to_array(); + // Lanes 1 (5>4) and 3 (7>4) should be 99.0 + assert_eq!(result, [0.0, 99.0, 0.0, 99.0]); +} + +// ════════════════════════════════════════════════════════════════════════════ +// 8. 
Exotic methods — permute_bytes reverse / nibble_popcount_lut +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn u8x64_permute_bytes_reverse_identity() { + // idx[i] = 63 - i reverses the vector + let v = u8_types::U8x64::from_array(core::array::from_fn(|i| i as u8)); + let idx = u8_types::U8x64::from_array(core::array::from_fn(|i| (63 - i) as u8)); + let r = v.permute_bytes(idx); + for i in 0..64 { + assert_eq!(r.to_array()[i], (63 - i) as u8, "lane {i}"); + } +} + +#[test] +fn u8x32_permute_bytes_reverse_identity() { + // idx[i] = 31 - i reverses the vector + let v = u8_types::U8x32::from_array(core::array::from_fn(|i| i as u8)); + let idx = u8_types::U8x32::from_array(core::array::from_fn(|i| (31 - i) as u8)); + let r = v.permute_bytes(idx); + for i in 0..32 { + assert_eq!(r.to_array()[i], (31 - i) as u8, "lane {i}"); + } +} + +#[test] +fn u8x64_nibble_popcount_lut_vs_scalar_count_ones() { + let lut = u8_types::U8x64::nibble_popcount_lut(); + let lut_arr = lut.to_array(); + // First 16 entries map nibble value → popcount + for nibble in 0u8..16 { + let simd_count = lut_arr[nibble as usize]; + let scalar_count = nibble.count_ones() as u8; + assert_eq!(simd_count, scalar_count, "nibble {nibble:#04x}"); + } +} + +#[test] +fn u8x32_nibble_popcount_lut_vs_scalar_count_ones() { + let lut = u8_types::U8x32::nibble_popcount_lut(); + let lut_arr = lut.to_array(); + // First 16 entries + for nibble in 0u8..16 { + let simd_count = lut_arr[nibble as usize]; + let scalar_count = nibble.count_ones() as u8; + assert_eq!(simd_count, scalar_count, "nibble {nibble:#04x}"); + } +} + +#[test] +fn u8x64_shuffle_bytes_popcount_parity() { + // Use shuffle_bytes to look up popcount of each nibble in test data. + // Input byte b → low nibble → lut[b & 0xF] gives popcount of low nibble. + let lut = u8_types::U8x64::nibble_popcount_lut(); + let data: [u8; 64] = core::array::from_fn(|i| (i % 16) as u8); + let idx = u8_types::U8x64::from_array(data); + let result = lut.shuffle_bytes(idx).to_array(); + for i in 0..64 { + let nibble = data[i]; + let expected = nibble.count_ones() as u8; + assert_eq!(result[i], expected, "data[{i}]={nibble} expected popcount {expected}"); + } +} + +// ════════════════════════════════════════════════════════════════════════════ +// 9. 
U-word types: constructor roundtrip + reduction parity +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn u64x8_splat_and_reduce_parity() { + let src: [u64; 8] = core::array::from_fn(|i| (i as u64) * 7); + let v = u_word_types::U64x8::from_array(src); + let simd_sum = v.reduce_sum(); + let scalar_sum: u64 = src.iter().sum(); + assert_eq!(simd_sum, scalar_sum); + assert_eq!(v.reduce_min(), *src.iter().min().unwrap()); + assert_eq!(v.reduce_max(), *src.iter().max().unwrap()); +} + +#[test] +fn u64x4_splat_and_reduce_parity() { + let src: [u64; 4] = core::array::from_fn(|i| (i as u64) * 11); + let v = u_word_types::U64x4::from_array(src); + let simd_sum = v.reduce_sum(); + let scalar_sum: u64 = src.iter().sum(); + assert_eq!(simd_sum, scalar_sum); + assert_eq!(v.reduce_min(), *src.iter().min().unwrap()); + assert_eq!(v.reduce_max(), *src.iter().max().unwrap()); +} + +#[test] +fn u32x16_splat_and_reduce_parity() { + let src: [u32; 16] = core::array::from_fn(|i| (i as u32) * 7); + let v = u_word_types::U32x16::from_array(src); + let simd_sum = v.reduce_sum(); + let scalar_sum: u32 = src.iter().sum(); + assert_eq!(simd_sum, scalar_sum); + assert_eq!(v.reduce_min(), *src.iter().min().unwrap()); + assert_eq!(v.reduce_max(), *src.iter().max().unwrap()); +} + +#[test] +fn u32x8_splat_and_reduce_parity() { + let src: [u32; 8] = core::array::from_fn(|i| (i as u32) * 5); + let v = u_word_types::U32x8::from_array(src); + let simd_sum = v.reduce_sum(); + let scalar_sum: u32 = src.iter().sum(); + assert_eq!(simd_sum, scalar_sum); + assert_eq!(v.reduce_min(), *src.iter().min().unwrap()); + assert_eq!(v.reduce_max(), *src.iter().max().unwrap()); +} + +#[test] +fn u16x32_splat_and_reduce_parity() { + let src: [u16; 32] = core::array::from_fn(|i| (i as u16) * 3); + let v = u_word_types::U16x32::from_array(src); + assert_eq!(v.reduce_min(), *src.iter().min().unwrap()); + assert_eq!(v.reduce_max(), *src.iter().max().unwrap()); +} + +// ════════════════════════════════════════════════════════════════════════════ +// 10. 
F32x16 math ops — sqrt / abs / floor / round vs scalar +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn f32x16_sqrt_parity() { + let src: [f32; 16] = core::array::from_fn(|i| (i as f32) * 2.0 + 1.0); + let result = F32x16::from_array(src).sqrt().to_array(); + for (i, (&r, &s)) in result.iter().zip(src.iter()).enumerate() { + let expected = s.sqrt(); + assert!((r - expected).abs() < 1e-5, "lane {i}: sqrt({s}) = {r} expected {expected}"); + } +} + +#[test] +fn f32x16_abs_parity() { + let src: [f32; 16] = core::array::from_fn(|i| if i % 2 == 0 { i as f32 } else { -(i as f32) }); + let result = F32x16::from_array(src).abs().to_array(); + for (i, (&r, &s)) in result.iter().zip(src.iter()).enumerate() { + assert_eq!(r, s.abs(), "lane {i}"); + } +} + +#[test] +fn f32x16_floor_parity() { + let src: [f32; 16] = core::array::from_fn(|i| i as f32 * 0.7 - 3.5); + let result = F32x16::from_array(src).floor().to_array(); + for (i, (&r, &s)) in result.iter().zip(src.iter()).enumerate() { + assert_eq!(r, s.floor(), "lane {i}"); + } +} + +#[test] +fn f32x16_round_parity() { + let src: [f32; 16] = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, -0.5, -1.5, 0.4, 1.6, 2.3, 3.7, -0.4, -1.6, 100.0, -100.0]; + let result = F32x16::from_array(src).round().to_array(); + for (i, (&r, &s)) in result.iter().zip(src.iter()).enumerate() { + assert_eq!(r, s.round(), "lane {i}"); + } +} + +// ════════════════════════════════════════════════════════════════════════════ +// 11. F32x16 to_bits / from_bits roundtrip +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn f32x16_bits_roundtrip() { + let src: [f32; 16] = core::array::from_fn(|i| (i as f32) * 1.5); + let v = F32x16::from_array(src); + let bits = v.to_bits(); + let back = F32x16::from_bits(bits); + assert_eq!(v.to_array(), back.to_array()); +} + +#[test] +fn f32x16_bits_match_scalar() { + let src: [f32; 16] = core::array::from_fn(|i| i as f32); + let bits = F32x16::from_array(src).to_bits().to_array(); + for (i, (&b, &s)) in bits.iter().zip(src.iter()).enumerate() { + assert_eq!(b, s.to_bits(), "lane {i}"); + } +} + +// ════════════════════════════════════════════════════════════════════════════ +// 12. 
Operator impls — Add/Sub/Mul/Div/BitAnd/BitOr/BitXor scalar parity +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn f32x16_add_sub_mul_div_parity() { + let a: [f32; 16] = core::array::from_fn(|i| (i as f32) + 1.0); + let b: [f32; 16] = core::array::from_fn(|i| (i as f32) * 0.5 + 0.1); + let va = F32x16::from_array(a); + let vb = F32x16::from_array(b); + let add = (va + vb).to_array(); + let sub = (va - vb).to_array(); + let mul = (va * vb).to_array(); + let div = (va / vb).to_array(); + for i in 0..16 { + assert!((add[i] - (a[i] + b[i])).abs() < 1e-5, "add lane {i}"); + assert!((sub[i] - (a[i] - b[i])).abs() < 1e-5, "sub lane {i}"); + assert!((mul[i] - (a[i] * b[i])).abs() < 1e-5, "mul lane {i}"); + assert!((div[i] - (a[i] / b[i])).abs() < 1e-4, "div lane {i}"); + } +} + +#[test] +fn u8x64_bitwise_ops_parity() { + let a: [u8; 64] = core::array::from_fn(|i| i as u8); + let b: [u8; 64] = core::array::from_fn(|i| (255 - i) as u8); + let va = u8_types::U8x64::from_array(a); + let vb = u8_types::U8x64::from_array(b); + let and_r = (va & vb).to_array(); + let or_r = (va | vb).to_array(); + let xor_r = (va ^ vb).to_array(); + for i in 0..64 { + assert_eq!(and_r[i], a[i] & b[i], "AND lane {i}"); + assert_eq!(or_r[i], a[i] | b[i], "OR lane {i}"); + assert_eq!(xor_r[i], a[i] ^ b[i], "XOR lane {i}"); + } +} + +// ════════════════════════════════════════════════════════════════════════════ +// 13. Mask bitmask roundtrip — from_bitmask(to_bitmask(m)) == m +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn f32mask16_bitmask_roundtrip() { + let orig: u16 = 0b1010_1010_0101_0101; + let mask = masks::F32Mask16::from_bitmask(orig); + assert_eq!(mask.to_bitmask(), orig); +} + +#[test] +fn f32mask8_bitmask_roundtrip() { + let orig: u8 = 0b1100_0011; + let mask = masks::F32Mask8::from_bitmask(orig); + assert_eq!(mask.to_bitmask(), orig); +} + +#[test] +fn f64mask8_bitmask_roundtrip() { + let orig: u8 = 0b0101_1010; + let mask = masks::F64Mask8::from_bitmask(orig); + assert_eq!(mask.to_bitmask(), orig); +} + +#[test] +fn f64mask4_bitmask_roundtrip() { + let orig: u8 = 0b1101; // only low 4 bits matter + let mask = masks::F64Mask4::from_bitmask(orig); + assert_eq!(mask.to_bitmask() & 0xF, orig & 0xF); +} + +// ════════════════════════════════════════════════════════════════════════════ +// 14. simd_clamp parity — lanes clamped exactly as scalar +// ════════════════════════════════════════════════════════════════════════════ + +#[test] +fn f32x16_simd_clamp_parity() { + let src: [f32; 16] = core::array::from_fn(|i| i as f32 - 5.0); // -5..10 + let lo = F32x16::splat(0.0); + let hi = F32x16::splat(8.0); + let result = F32x16::from_array(src).simd_clamp(lo, hi).to_array(); + for (i, (&r, &s)) in result.iter().zip(src.iter()).enumerate() { + let expected = s.clamp(0.0, 8.0); + assert_eq!(r, expected, "lane {i}"); + } +} + +#[test] +fn f64x8_simd_clamp_parity() { + let src: [f64; 8] = core::array::from_fn(|i| i as f64 - 3.0); // -3..4 + let lo = F64x8::splat(-1.0); + let hi = F64x8::splat(3.0); + let result = F64x8::from_array(src).simd_clamp(lo, hi).to_array(); + for (i, (&r, &s)) in result.iter().zip(src.iter()).enumerate() { + let expected = s.clamp(-1.0, 3.0); + assert_eq!(r, expected, "lane {i}"); + } +} diff --git a/src/simd_nightly/u8_types.rs b/src/simd_nightly/u8_types.rs new file mode 100644 index 00000000..01d6c5cf --- /dev/null +++ b/src/simd_nightly/u8_types.rs @@ -0,0 +1,1040 @@ +//! 
U8x32 / U8x64 portable-simd wrappers — round-3-portable-simd agent #3. +#![cfg(feature = "nightly-simd")] + +use core::simd::cmp::{SimdOrd, SimdPartialEq, SimdPartialOrd}; +use core::simd::num::SimdUint; +use core::simd::{u8x32 as core_u8x32, u8x64 as core_u8x64, Simd}; + +// ════════════════════════════════════════════════════════════════════ +// U8x64 — 64-lane unsigned byte (rasterizer / palette / NBT width) +// ════════════════════════════════════════════════════════════════════ + +/// 64-lane `u8` SIMD vector backed by `core::simd::u8x64`. +/// +/// API mirrors `simd_avx512::U8x64` so consumer code is identical under +/// both the intrinsics and the portable-simd backend. Miri can execute +/// every method below. +/// +/// # Examples +/// ```rust +/// # #[cfg(feature = "nightly-simd")] { +/// use ndarray::simd_nightly::U8x64; +/// let a = U8x64::splat(10); +/// let b = U8x64::splat(20); +/// assert_eq!(a.simd_min(b).to_array()[0], 10); +/// # } +/// ``` +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct U8x64(pub core_u8x64); + +impl U8x64 { + /// Number of `u8` lanes. + pub const LANES: usize = 64; + + // ── Constructors ──────────────────────────────────────────────── + + /// Broadcast a single byte to all 64 lanes. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// assert_eq!(U8x64::splat(7).to_array()[0], 7); + /// # } + /// ``` + #[inline(always)] + pub fn splat(v: u8) -> Self { + Self(core_u8x64::splat(v)) + } + + /// Unaligned load 64 bytes from a slice. Panics if `s.len() < 64`. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// let data = [1u8; 64]; + /// let v = U8x64::from_slice(&data); + /// assert_eq!(v.to_array()[0], 1); + /// # } + /// ``` + #[inline(always)] + pub fn from_slice(s: &[u8]) -> Self { + assert!(s.len() >= 64, "U8x64::from_slice needs ≥64 bytes"); + Self(core_u8x64::from_slice(s)) + } + + /// Load 64 bytes from a fixed-size array. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// let arr = [2u8; 64]; + /// assert_eq!(U8x64::from_array(arr).to_array(), arr); + /// # } + /// ``` + #[inline(always)] + pub fn from_array(arr: [u8; 64]) -> Self { + Self(core_u8x64::from_array(arr)) + } + + /// Store all 64 bytes into a `[u8; 64]` array. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// let v = U8x64::splat(3); + /// assert_eq!(v.to_array()[63], 3); + /// # } + /// ``` + #[inline(always)] + pub fn to_array(self) -> [u8; 64] { + self.0.to_array() + } + + /// Copy all 64 bytes into a mutable slice. Panics if `s.len() < 64`. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// let mut buf = [0u8; 64]; + /// U8x64::splat(9).copy_to_slice(&mut buf); + /// assert_eq!(buf[0], 9); + /// # } + /// ``` + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [u8]) { + assert!(s.len() >= 64, "U8x64::copy_to_slice needs ≥64 bytes"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + /// Wrapping sum of all 64 lanes → `u8` (wraps at 256). 
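+    ///
+    /// Because the lane type is `u8`, the horizontal total wraps modulo 256;
+    /// use [`sum_bytes_u64`](Self::sum_bytes_u64) when the true sum is needed.
+    /// A wrap sketch (mirrors the `u8x64_reduce_sum_wraps` test below):
+    ///
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// assert_eq!(U8x64::splat(4).reduce_sum(), 0);      // 64 × 4 = 256 wraps to 0
+    /// assert_eq!(U8x64::splat(4).sum_bytes_u64(), 256); // promoted — no wrap
+    /// # }
+    /// ```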
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// assert_eq!(U8x64::splat(1).reduce_sum(), 64); // 64 lanes × 1 — no wrap
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn reduce_sum(self) -> u8 {
+        self.0.reduce_sum()
+    }
+
+    /// Unsigned minimum across all 64 lanes.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// let mut arr = [10u8; 64];
+    /// arr[5] = 2;
+    /// assert_eq!(U8x64::from_array(arr).reduce_min(), 2);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn reduce_min(self) -> u8 {
+        self.0.reduce_min()
+    }
+
+    /// Unsigned maximum across all 64 lanes.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// let mut arr = [10u8; 64];
+    /// arr[5] = 200;
+    /// assert_eq!(U8x64::from_array(arr).reduce_max(), 200);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn reduce_max(self) -> u8 {
+        self.0.reduce_max()
+    }
+
+    /// Horizontal byte sum as `u64` — does NOT wrap at 256.
+    ///
+    /// Promotes each byte lane to `u16`, then reduces. Range: 0..=64×255 = 16_320.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// assert_eq!(U8x64::splat(1).sum_bytes_u64(), 64);
+    /// assert_eq!(U8x64::splat(255).sum_bytes_u64(), 64 * 255);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn sum_bytes_u64(self) -> u64 {
+        // Promote u8 × 64 → u16 × 64 to avoid 8-bit wrapping, then reduce.
+        let v16: Simd<u16, 64> = self.0.cast::<u16>();
+        v16.reduce_sum() as u64
+    }
+
+    // ── Lane-wise min / max ───────────────────────────────────────
+
+    /// Per-lane unsigned min.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// let a = U8x64::splat(100);
+    /// let b = U8x64::splat(50);
+    /// assert_eq!(a.simd_min(b).to_array()[0], 50);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn simd_min(self, other: Self) -> Self {
+        Self(self.0.simd_min(other.0))
+    }
+
+    /// Per-lane unsigned max.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// let a = U8x64::splat(100);
+    /// let b = U8x64::splat(50);
+    /// assert_eq!(a.simd_max(b).to_array()[0], 100);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn simd_max(self, other: Self) -> Self {
+        Self(self.0.simd_max(other.0))
+    }
+
+    // ── Saturating arithmetic ────────────────────────────────────────
+
+    /// Per-lane saturating unsigned add: `min(a + b, 255)`.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// let a = U8x64::splat(200);
+    /// let b = U8x64::splat(100);
+    /// assert_eq!(a.saturating_add(b).to_array()[0], 255);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn saturating_add(self, other: Self) -> Self {
+        Self(self.0.saturating_add(other.0))
+    }
+
+    /// Per-lane saturating unsigned sub: `max(a - b, 0)`.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// let a = U8x64::splat(10);
+    /// let b = U8x64::splat(20);
+    /// assert_eq!(a.saturating_sub(b).to_array()[0], 0);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn saturating_sub(self, other: Self) -> Self {
+        Self(self.0.saturating_sub(other.0))
+    }
+
+    /// Per-lane unsigned rounded average: `(a + b + 1) >> 1`.
+    ///
+    /// `core::simd` has no native `avg_epu8`; computed via u16 promotion
+    /// to avoid overflow. LLVM may lower to `vpavgb` on AVX-512 builds.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// let a = U8x64::splat(10);
+    /// let b = U8x64::splat(11);
+    /// // (10 + 11 + 1) / 2 = 11
+    /// assert_eq!(a.pairwise_avg(b).to_array()[0], 11);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn pairwise_avg(self, other: Self) -> Self {
+        let a16: Simd<u16, 64> = self.0.cast::<u16>();
+        let b16: Simd<u16, 64> = other.0.cast::<u16>();
+        let avg = (a16 + b16 + Simd::splat(1u16)) >> Simd::splat(1u16);
+        Self(avg.cast::<u8>())
+    }
+
+    // ── 16-bit-lane shifts (nibble pack/unpack) ─────────────────────
+
+    /// Right shift each 16-bit lane by `imm` bits.
+    ///
+    /// Operates on 16-bit lanes (same semantics as `_mm512_srli_epi16`).
+    /// Reinterprets `u8 × 64` as `u16 × 32`, shifts, reinterprets back.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// // 0x10 in every byte → each u16 lane is 0x1010; >> 4 gives 0x0101.
+    /// let v = U8x64::splat(0x10);
+    /// assert_eq!(v.shr_epi16(4).to_array()[0], 0x01);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn shr_epi16(self, imm: u32) -> Self {
+        // Reinterpret u8×64 as u16×32, apply the shift, reinterpret back.
+        let bytes = self.to_array();
+        // SAFETY: [u8; 64] has the same size as [u16; 32].
+        let mut words: [u16; 32] = unsafe { core::mem::transmute(bytes) };
+        for w in words.iter_mut() {
+            *w >>= imm;
+        }
+        Self::from_array(unsafe { core::mem::transmute(words) })
+    }
+
+    /// Left shift each 16-bit lane by `imm` bits.
+    ///
+    /// Operates on 16-bit lanes (same semantics as `_mm512_slli_epi16`).
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// let v = U8x64::splat(0x01);
+    /// // Each u16 lane is 0x0101; << 4 gives 0x1010, so every byte reads 0x10.
+    /// assert_eq!(v.shl_epi16(4).to_array()[0], 0x10);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn shl_epi16(self, imm: u32) -> Self {
+        let bytes = self.to_array();
+        let mut words: [u16; 32] = unsafe { core::mem::transmute(bytes) };
+        for w in words.iter_mut() {
+            *w <<= imm;
+        }
+        Self::from_array(unsafe { core::mem::transmute(words) })
+    }
+
+    // ── Comparison → 64-bit bitmask ───────────────────────────────
+
+    /// Per-lane equality → 64-bit bitmask (bit `i` set iff `self[i] == other[i]`).
+    ///
+    /// Matches `simd_avx512::U8x64::cmpeq_mask` shape.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// let a = U8x64::splat(5);
+    /// let b = U8x64::splat(5);
+    /// assert_eq!(a.cmpeq_mask(b), u64::MAX);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn cmpeq_mask(self, other: Self) -> u64 {
+        self.0.simd_eq(other.0).to_bitmask()
+    }
+
+    /// Per-lane unsigned greater-than → 64-bit bitmask.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// let a = U8x64::splat(10);
+    /// let b = U8x64::splat(5);
+    /// assert_eq!(a.cmpgt_mask(b), u64::MAX);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn cmpgt_mask(self, other: Self) -> u64 {
+        self.0.simd_gt(other.0).to_bitmask()
+    }
+
+    /// Extract MSB of each lane as a 64-bit mask.
+    ///
+    /// Bit `i` is set iff `self[i] >= 128`. Equivalent to `_mm512_movepi8_mask`.
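+    ///
+    /// Pairing this with `u64::count_ones` gives a cheap "how many lanes have
+    /// the high bit set" query (an illustrative combination, not part of the
+    /// AVX-512 reference surface):
+    ///
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x64;
+    /// let mut arr = [0u8; 64];
+    /// arr[0] = 200;
+    /// arr[7] = 255;
+    /// assert_eq!(U8x64::from_array(arr).movemask().count_ones(), 2);
+    /// # }
+    /// ```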
+ /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// assert_eq!(U8x64::splat(128).movemask(), u64::MAX); + /// assert_eq!(U8x64::splat(127).movemask(), 0); + /// # } + /// ``` + #[inline(always)] + pub fn movemask(self) -> u64 { + // MSB set ⇔ value > 127 in unsigned comparison. + self.0.simd_gt(core_u8x64::splat(0x7F)).to_bitmask() + } + + // ── Nibble popcount LUT ──────────────────────────────────────── + + /// Build the nibble-popcount lookup table (replicated across all 64 bytes). + /// + /// Entry `i` (for i in 0..16) = `popcount(i)`. Intended for use with the + /// shuffle-based Mula SIMD popcount algorithm in `exotic_methods.rs`. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x64; + /// let lut = U8x64::nibble_popcount_lut(); + /// // entry 0 → 0, entry 1 → 1, entry 3 → 2, entry 15 → 4 + /// assert_eq!(lut.to_array()[0], 0); + /// assert_eq!(lut.to_array()[3], 2); + /// assert_eq!(lut.to_array()[15], 4); + /// # } + /// ``` + #[inline(always)] + pub fn nibble_popcount_lut() -> Self { + // LUT for nibbles 0..=15, replicated 4× to fill 64 bytes. + Self::from_array([ + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, + 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + ]) + } +} + +impl Default for U8x64 { + #[inline(always)] + fn default() -> Self { + Self::splat(0) + } +} + +// ════════════════════════════════════════════════════════════════════ +// U8x32 — 32-lane unsigned byte (AVX2 width) +// ════════════════════════════════════════════════════════════════════ + +/// 32-lane `u8` SIMD vector backed by `core::simd::u8x32`. +/// +/// API mirrors `simd_avx2::U8x32` so consumer code is identical under +/// both the intrinsics and the portable-simd backend. Miri can execute +/// every method below. +/// +/// # Examples +/// ```rust +/// # #[cfg(feature = "nightly-simd")] { +/// use ndarray::simd_nightly::U8x32; +/// let a = U8x32::splat(10); +/// let b = U8x32::splat(20); +/// assert_eq!(a.simd_min(b).to_array()[0], 10); +/// # } +/// ``` +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct U8x32(pub core_u8x32); + +impl U8x32 { + /// Number of `u8` lanes. + pub const LANES: usize = 32; + + // ── Constructors ──────────────────────────────────────────────── + + /// Broadcast a single byte to all 32 lanes. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// assert_eq!(U8x32::splat(7).to_array()[0], 7); + /// # } + /// ``` + #[inline(always)] + pub fn splat(v: u8) -> Self { + Self(core_u8x32::splat(v)) + } + + /// Unaligned load 32 bytes from a slice. Panics if `s.len() < 32`. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// let data = [1u8; 32]; + /// let v = U8x32::from_slice(&data); + /// assert_eq!(v.to_array()[0], 1); + /// # } + /// ``` + #[inline(always)] + pub fn from_slice(s: &[u8]) -> Self { + assert!(s.len() >= 32, "U8x32::from_slice needs ≥32 bytes"); + Self(core_u8x32::from_slice(s)) + } + + /// Load 32 bytes from a fixed-size array. 
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// let arr = [2u8; 32];
+    /// assert_eq!(U8x32::from_array(arr).to_array(), arr);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn from_array(arr: [u8; 32]) -> Self {
+        Self(core_u8x32::from_array(arr))
+    }
+
+    /// Store all 32 bytes into a `[u8; 32]` array.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// let v = U8x32::splat(3);
+    /// assert_eq!(v.to_array()[31], 3);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn to_array(self) -> [u8; 32] {
+        self.0.to_array()
+    }
+
+    /// Copy all 32 bytes into a mutable slice. Panics if `s.len() < 32`.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// let mut buf = [0u8; 32];
+    /// U8x32::splat(9).copy_to_slice(&mut buf);
+    /// assert_eq!(buf[0], 9);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn copy_to_slice(self, s: &mut [u8]) {
+        assert!(s.len() >= 32, "U8x32::copy_to_slice needs ≥32 bytes");
+        self.0.copy_to_slice(s);
+    }
+
+    // ── Reductions ──────────────────────────────────────────────
+
+    /// Wrapping sum of all 32 lanes → `u8` (wraps at 256).
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// assert_eq!(U8x32::splat(2).reduce_sum(), 64); // 32 × 2 = 64 — no wrap
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn reduce_sum(self) -> u8 {
+        self.0.reduce_sum()
+    }
+
+    /// Unsigned minimum across all 32 lanes.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// let mut arr = [10u8; 32];
+    /// arr[5] = 2;
+    /// assert_eq!(U8x32::from_array(arr).reduce_min(), 2);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn reduce_min(self) -> u8 {
+        self.0.reduce_min()
+    }
+
+    /// Unsigned maximum across all 32 lanes.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// let mut arr = [10u8; 32];
+    /// arr[5] = 200;
+    /// assert_eq!(U8x32::from_array(arr).reduce_max(), 200);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn reduce_max(self) -> u8 {
+        self.0.reduce_max()
+    }
+
+    /// Horizontal byte sum as `u64` — does NOT wrap at 256.
+    ///
+    /// Promotes each byte lane to `u16`, then reduces. Range: 0..=32×255 = 8_160.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// assert_eq!(U8x32::splat(1).sum_bytes_u64(), 32);
+    /// assert_eq!(U8x32::splat(255).sum_bytes_u64(), 32 * 255);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn sum_bytes_u64(self) -> u64 {
+        let v16: Simd<u16, 32> = self.0.cast::<u16>();
+        v16.reduce_sum() as u64
+    }
+
+    // ── Lane-wise min / max ───────────────────────────────────────
+
+    /// Per-lane unsigned min.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// let a = U8x32::splat(100);
+    /// let b = U8x32::splat(50);
+    /// assert_eq!(a.simd_min(b).to_array()[0], 50);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn simd_min(self, other: Self) -> Self {
+        Self(self.0.simd_min(other.0))
+    }
+
+    /// Per-lane unsigned max.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// let a = U8x32::splat(100);
+    /// let b = U8x32::splat(50);
+    /// assert_eq!(a.simd_max(b).to_array()[0], 100);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn simd_max(self, other: Self) -> Self {
+        Self(self.0.simd_max(other.0))
+    }
+
+    // ── Saturating arithmetic ────────────────────────────────────────
+
+    /// Per-lane saturating unsigned add: `min(a + b, 255)`.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// let a = U8x32::splat(200);
+    /// let b = U8x32::splat(100);
+    /// assert_eq!(a.saturating_add(b).to_array()[0], 255);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn saturating_add(self, other: Self) -> Self {
+        Self(self.0.saturating_add(other.0))
+    }
+
+    /// Per-lane saturating unsigned sub: `max(a - b, 0)`.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// let a = U8x32::splat(10);
+    /// let b = U8x32::splat(20);
+    /// assert_eq!(a.saturating_sub(b).to_array()[0], 0);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn saturating_sub(self, other: Self) -> Self {
+        Self(self.0.saturating_sub(other.0))
+    }
+
+    /// Per-lane unsigned rounded average: `(a + b + 1) >> 1`.
+    ///
+    /// `core::simd` has no native `avg_epu8`; computed via u16 promotion
+    /// to avoid overflow. LLVM may lower to `vpavgb` on AVX2 builds.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// let a = U8x32::splat(10);
+    /// let b = U8x32::splat(11);
+    /// // (10 + 11 + 1) / 2 = 11
+    /// assert_eq!(a.pairwise_avg(b).to_array()[0], 11);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn pairwise_avg(self, other: Self) -> Self {
+        let a16: Simd<u16, 32> = self.0.cast::<u16>();
+        let b16: Simd<u16, 32> = other.0.cast::<u16>();
+        let avg = (a16 + b16 + Simd::splat(1u16)) >> Simd::splat(1u16);
+        Self(avg.cast::<u8>())
+    }
+
+    // ── 16-bit-lane shifts (nibble pack/unpack) ─────────────────────
+
+    /// Right shift each 16-bit lane by `imm` bits.
+    ///
+    /// Operates on 16-bit lanes (same semantics as `_mm256_srli_epi16`).
+    /// Reinterprets `u8 × 32` as `u16 × 16`, shifts, reinterprets back.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # #[cfg(feature = "nightly-simd")] {
+    /// use ndarray::simd_nightly::U8x32;
+    /// let v = U8x32::splat(0x10);
+    /// assert_eq!(v.shr_epi16(4).to_array()[0], 0x01);
+    /// # }
+    /// ```
+    #[inline(always)]
+    pub fn shr_epi16(self, imm: u32) -> Self {
+        let bytes = self.to_array();
+        // SAFETY: [u8; 32] and [u16; 16] have identical size; the transmute
+        // is by value, so alignment is not a concern.
+        let mut words: [u16; 16] = unsafe { core::mem::transmute(bytes) };
+        for w in words.iter_mut() {
+            *w >>= imm;
+        }
+        Self::from_array(unsafe { core::mem::transmute(words) })
+    }
+
+    /// Left shift each 16-bit lane by `imm` bits.
+    ///
+    /// Operates on 16-bit lanes (same semantics as `_mm256_slli_epi16`).
+ /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// let v = U8x32::splat(0x01); + /// assert_eq!(v.shl_epi16(4).to_array()[0], 0x10); + /// # } + /// ``` + #[inline(always)] + pub fn shl_epi16(self, imm: u32) -> Self { + let bytes = self.to_array(); + let mut words: [u16; 16] = unsafe { core::mem::transmute(bytes) }; + for w in words.iter_mut() { + *w <<= imm; + } + Self::from_array(unsafe { core::mem::transmute(words) }) + } + + // ── Comparison → 32-bit bitmask ─────────────────────────────── + + /// Per-lane equality → 32-bit bitmask (bit `i` set iff `self[i] == other[i]`). + /// + /// Matches `simd_avx2::U8x32::cmpeq_mask` shape. `core::simd::to_bitmask()` + /// returns `u64` for all lane widths ≤ 64; the upper 32 bits are always zero. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// let a = U8x32::splat(5); + /// let b = U8x32::splat(5); + /// assert_eq!(a.cmpeq_mask(b), u32::MAX); + /// # } + /// ``` + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u32 { + self.0.simd_eq(other.0).to_bitmask() as u32 + } + + /// Per-lane unsigned greater-than → 32-bit bitmask. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// let a = U8x32::splat(10); + /// let b = U8x32::splat(5); + /// assert_eq!(a.cmpgt_mask(b), u32::MAX); + /// # } + /// ``` + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u32 { + self.0.simd_gt(other.0).to_bitmask() as u32 + } + + /// Extract MSB of each lane as a 32-bit mask. + /// + /// Bit `i` is set iff `self[i] >= 128`. Equivalent to `_mm256_movemask_epi8` + /// but limited to MSB of each byte. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// assert_eq!(U8x32::splat(128).movemask(), u32::MAX); + /// assert_eq!(U8x32::splat(127).movemask(), 0); + /// # } + /// ``` + #[inline(always)] + pub fn movemask(self) -> u32 { + self.0.simd_gt(core_u8x32::splat(0x7F)).to_bitmask() as u32 + } + + // ── Nibble popcount LUT ──────────────────────────────────────── + + /// Build the nibble-popcount lookup table (replicated across both 16-byte halves). + /// + /// Entry `i` (for i in 0..16) = `popcount(i)`. Intended for use with the + /// shuffle-based Mula SIMD popcount algorithm in `exotic_methods.rs`. + /// + /// # Examples + /// ```rust + /// # #[cfg(feature = "nightly-simd")] { + /// use ndarray::simd_nightly::U8x32; + /// let lut = U8x32::nibble_popcount_lut(); + /// assert_eq!(lut.to_array()[0], 0); + /// assert_eq!(lut.to_array()[3], 2); + /// assert_eq!(lut.to_array()[15], 4); + /// # } + /// ``` + #[inline(always)] + pub fn nibble_popcount_lut() -> Self { + // LUT replicated twice to fill 32 bytes (2 × 128-bit halves). 
+ Self::from_array([ + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + ]) + } +} + +impl Default for U8x32 { + #[inline(always)] + fn default() -> Self { + Self::splat(0) + } +} + +// ════════════════════════════════════════════════════════════════════ +// Tests +// ════════════════════════════════════════════════════════════════════ + +#[cfg(test)] +mod tests { + use super::*; + + // ── U8x64 tests ───────────────────────────────────────────────── + + #[test] + fn u8x64_splat_to_array() { + let v = U8x64::splat(42); + assert!(v.to_array().iter().all(|&b| b == 42)); + } + + #[test] + fn u8x64_from_slice_roundtrip() { + let src: [u8; 64] = core::array::from_fn(|i| i as u8); + let mut dst = [0u8; 64]; + U8x64::from_slice(&src).copy_to_slice(&mut dst); + assert_eq!(src, dst); + } + + #[test] + fn u8x64_reduce_min_max() { + let mut arr = [100u8; 64]; + arr[0] = 5; + arr[63] = 200; + let v = U8x64::from_array(arr); + assert_eq!(v.reduce_min(), 5); + assert_eq!(v.reduce_max(), 200); + } + + #[test] + fn u8x64_reduce_sum_wraps() { + // 64 * 4 = 256 which wraps to 0 + assert_eq!(U8x64::splat(4).reduce_sum(), 0u8); + } + + #[test] + fn u8x64_sum_bytes_u64_no_wrap() { + assert_eq!(U8x64::splat(255).sum_bytes_u64(), 64 * 255); + } + + #[test] + fn u8x64_simd_min_max() { + let a = U8x64::splat(100); + let b = U8x64::splat(50); + assert!(a.simd_min(b).to_array().iter().all(|&x| x == 50)); + assert!(a.simd_max(b).to_array().iter().all(|&x| x == 100)); + } + + #[test] + fn u8x64_saturating_add() { + let a = U8x64::splat(200); + let b = U8x64::splat(100); + assert!(a.saturating_add(b).to_array().iter().all(|&x| x == 255)); + } + + #[test] + fn u8x64_saturating_sub() { + let a = U8x64::splat(10); + let b = U8x64::splat(20); + assert!(a.saturating_sub(b).to_array().iter().all(|&x| x == 0)); + } + + #[test] + fn u8x64_pairwise_avg() { + // (10 + 11 + 1) / 2 = 11 + let a = U8x64::splat(10); + let b = U8x64::splat(11); + assert!(a.pairwise_avg(b).to_array().iter().all(|&x| x == 11)); + // (0 + 1 + 1) / 2 = 1 + let c = U8x64::splat(0); + let d = U8x64::splat(1); + assert!(c.pairwise_avg(d).to_array().iter().all(|&x| x == 1)); + } + + #[test] + fn u8x64_cmpeq_mask_all_eq() { + assert_eq!(U8x64::splat(5).cmpeq_mask(U8x64::splat(5)), u64::MAX); + } + + #[test] + fn u8x64_cmpeq_mask_none_eq() { + assert_eq!(U8x64::splat(1).cmpeq_mask(U8x64::splat(2)), 0u64); + } + + #[test] + fn u8x64_cmpgt_mask() { + assert_eq!(U8x64::splat(10).cmpgt_mask(U8x64::splat(5)), u64::MAX); + assert_eq!(U8x64::splat(5).cmpgt_mask(U8x64::splat(10)), 0u64); + } + + #[test] + fn u8x64_movemask() { + assert_eq!(U8x64::splat(128).movemask(), u64::MAX); + assert_eq!(U8x64::splat(127).movemask(), 0u64); + } + + #[test] + fn u8x64_shr_epi16() { + // 0x10 >> 4 = 0x01 in the low byte of each u16 lane. + let v = U8x64::splat(0x10); + assert_eq!(v.shr_epi16(4).to_array()[0], 0x01); + } + + #[test] + fn u8x64_shl_epi16() { + let v = U8x64::splat(0x01); + assert_eq!(v.shl_epi16(4).to_array()[0], 0x10); + } + + #[test] + fn u8x64_nibble_popcount_lut() { + let lut = U8x64::nibble_popcount_lut(); + let arr = lut.to_array(); + assert_eq!(arr[0], 0); + assert_eq!(arr[1], 1); + assert_eq!(arr[3], 2); + assert_eq!(arr[7], 3); + assert_eq!(arr[15], 4); + // Verify second repetition at offset 16. 
+ assert_eq!(arr[16], 0); + assert_eq!(arr[31], 4); + } + + // ── U8x32 tests ───────────────────────────────────────────────── + + #[test] + fn u8x32_splat_to_array() { + let v = U8x32::splat(42); + assert!(v.to_array().iter().all(|&b| b == 42)); + } + + #[test] + fn u8x32_from_slice_roundtrip() { + let src: [u8; 32] = core::array::from_fn(|i| i as u8); + let mut dst = [0u8; 32]; + U8x32::from_slice(&src).copy_to_slice(&mut dst); + assert_eq!(src, dst); + } + + #[test] + fn u8x32_reduce_min_max() { + let mut arr = [100u8; 32]; + arr[0] = 5; + arr[31] = 200; + let v = U8x32::from_array(arr); + assert_eq!(v.reduce_min(), 5); + assert_eq!(v.reduce_max(), 200); + } + + #[test] + fn u8x32_sum_bytes_u64_no_wrap() { + assert_eq!(U8x32::splat(255).sum_bytes_u64(), 32 * 255); + } + + #[test] + fn u8x32_simd_min_max() { + let a = U8x32::splat(100); + let b = U8x32::splat(50); + assert!(a.simd_min(b).to_array().iter().all(|&x| x == 50)); + assert!(a.simd_max(b).to_array().iter().all(|&x| x == 100)); + } + + #[test] + fn u8x32_saturating_add() { + let a = U8x32::splat(200); + let b = U8x32::splat(100); + assert!(a.saturating_add(b).to_array().iter().all(|&x| x == 255)); + } + + #[test] + fn u8x32_saturating_sub() { + let a = U8x32::splat(10); + let b = U8x32::splat(20); + assert!(a.saturating_sub(b).to_array().iter().all(|&x| x == 0)); + } + + #[test] + fn u8x32_pairwise_avg() { + let a = U8x32::splat(10); + let b = U8x32::splat(11); + assert!(a.pairwise_avg(b).to_array().iter().all(|&x| x == 11)); + } + + #[test] + fn u8x32_cmpeq_mask_all_eq() { + assert_eq!(U8x32::splat(5).cmpeq_mask(U8x32::splat(5)), u32::MAX); + } + + #[test] + fn u8x32_cmpgt_mask() { + assert_eq!(U8x32::splat(10).cmpgt_mask(U8x32::splat(5)), u32::MAX); + assert_eq!(U8x32::splat(5).cmpgt_mask(U8x32::splat(10)), 0u32); + } + + #[test] + fn u8x32_movemask() { + assert_eq!(U8x32::splat(128).movemask(), u32::MAX); + assert_eq!(U8x32::splat(127).movemask(), 0u32); + } + + #[test] + fn u8x32_shr_epi16() { + let v = U8x32::splat(0x10); + assert_eq!(v.shr_epi16(4).to_array()[0], 0x01); + } + + #[test] + fn u8x32_shl_epi16() { + let v = U8x32::splat(0x01); + assert_eq!(v.shl_epi16(4).to_array()[0], 0x10); + } + + #[test] + fn u8x32_nibble_popcount_lut() { + let lut = U8x32::nibble_popcount_lut(); + let arr = lut.to_array(); + assert_eq!(arr[0], 0); + assert_eq!(arr[3], 2); + assert_eq!(arr[15], 4); + assert_eq!(arr[16], 0); + assert_eq!(arr[31], 4); + } +} diff --git a/src/simd_nightly/u_word_types.rs b/src/simd_nightly/u_word_types.rs new file mode 100644 index 00000000..a20bc3ce --- /dev/null +++ b/src/simd_nightly/u_word_types.rs @@ -0,0 +1,585 @@ +//! U16x32 / U32x16 / U32x8 / U64x8 / U64x4 portable-simd wrappers — round-3-portable-simd agent #4. +#![cfg(feature = "nightly-simd")] + +use core::simd::cmp::{SimdOrd, SimdPartialEq, SimdPartialOrd}; +use core::simd::num::SimdUint; +use core::simd::{u16x32, u32x16, u32x8, u64x4, u64x8}; + +// ════════════════════════════════════════════════════════════════════ +// U64x8 — 8-lane u64 +// ════════════════════════════════════════════════════════════════════ + +/// 8-lane `u64` SIMD vector backed by `core::simd::u64x8`. +/// +/// Also used as the return type of `F64x8::to_bits`. 
+#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(transparent)] +pub struct U64x8(pub u64x8); + +impl U64x8 { + pub const LANES: usize = 8; + + #[inline(always)] + pub fn splat(v: u64) -> Self { + Self(u64x8::splat(v)) + } + + #[inline(always)] + pub fn from_slice(s: &[u64]) -> Self { + assert!(s.len() >= 8, "U64x8::from_slice needs >=8 elements"); + Self(u64x8::from_slice(s)) + } + + #[inline(always)] + pub fn from_array(arr: [u64; 8]) -> Self { + Self(u64x8::from_array(arr)) + } + + #[inline(always)] + pub fn to_array(self) -> [u64; 8] { + self.0.to_array() + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [u64]) { + assert!(s.len() >= 8, "U64x8::copy_to_slice needs >=8 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + #[inline(always)] + pub fn reduce_sum(self) -> u64 { + self.0.reduce_sum() + } + + #[inline(always)] + pub fn reduce_min(self) -> u64 { + self.0.reduce_min() + } + + #[inline(always)] + pub fn reduce_max(self) -> u64 { + self.0.reduce_max() + } + + // ── Lane-wise min/max ───────────────────────────────────────── + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + // ── Compare -> bitmask ──────────────────────────────────────── + + /// Per-lane equality. Returns an 8-bit bitmask (bit i = 1 iff lane i equal). + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u8 { + self.0.simd_eq(other.0).to_bitmask() as u8 + } + + /// Per-lane unsigned greater-than. Returns an 8-bit bitmask. + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u8 { + self.0.simd_gt(other.0).to_bitmask() as u8 + } +} + +impl Default for U64x8 { + #[inline(always)] + fn default() -> Self { + Self::splat(0) + } +} + +// ════════════════════════════════════════════════════════════════════ +// U64x4 — 4-lane u64 (companion for F64x4::to_bits) +// ════════════════════════════════════════════════════════════════════ + +/// 4-lane `u64` SIMD vector backed by `core::simd::u64x4`. +/// +/// Return type of `F64x4::to_bits`. 
+#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(transparent)] +pub struct U64x4(pub u64x4); + +impl U64x4 { + pub const LANES: usize = 4; + + #[inline(always)] + pub fn splat(v: u64) -> Self { + Self(u64x4::splat(v)) + } + + #[inline(always)] + pub fn from_slice(s: &[u64]) -> Self { + assert!(s.len() >= 4, "U64x4::from_slice needs >=4 elements"); + Self(u64x4::from_slice(s)) + } + + #[inline(always)] + pub fn from_array(arr: [u64; 4]) -> Self { + Self(u64x4::from_array(arr)) + } + + #[inline(always)] + pub fn to_array(self) -> [u64; 4] { + self.0.to_array() + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [u64]) { + assert!(s.len() >= 4, "U64x4::copy_to_slice needs >=4 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + #[inline(always)] + pub fn reduce_sum(self) -> u64 { + self.0.reduce_sum() + } + + #[inline(always)] + pub fn reduce_min(self) -> u64 { + self.0.reduce_min() + } + + #[inline(always)] + pub fn reduce_max(self) -> u64 { + self.0.reduce_max() + } + + // ── Lane-wise min/max ───────────────────────────────────────── + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + // ── Compare -> bitmask ──────────────────────────────────────── + + /// Per-lane equality. Returns an 8-bit value (4 bits used; bit i = 1 iff lane i equal). + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u8 { + self.0.simd_eq(other.0).to_bitmask() as u8 + } + + /// Per-lane unsigned greater-than. Returns an 8-bit value (4 bits used). + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u8 { + self.0.simd_gt(other.0).to_bitmask() as u8 + } +} + +impl Default for U64x4 { + #[inline(always)] + fn default() -> Self { + Self::splat(0) + } +} + +// ════════════════════════════════════════════════════════════════════ +// U32x8 — 8-lane u32 (companion for F32x8::to_bits) +// ════════════════════════════════════════════════════════════════════ + +/// 8-lane `u32` SIMD vector backed by `core::simd::u32x8`. +/// +/// Return type of `F32x8::to_bits`. 
+#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(transparent)] +pub struct U32x8(pub u32x8); + +impl U32x8 { + pub const LANES: usize = 8; + + #[inline(always)] + pub fn splat(v: u32) -> Self { + Self(u32x8::splat(v)) + } + + #[inline(always)] + pub fn from_slice(s: &[u32]) -> Self { + assert!(s.len() >= 8, "U32x8::from_slice needs >=8 elements"); + Self(u32x8::from_slice(s)) + } + + #[inline(always)] + pub fn from_array(arr: [u32; 8]) -> Self { + Self(u32x8::from_array(arr)) + } + + #[inline(always)] + pub fn to_array(self) -> [u32; 8] { + self.0.to_array() + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [u32]) { + assert!(s.len() >= 8, "U32x8::copy_to_slice needs >=8 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + #[inline(always)] + pub fn reduce_sum(self) -> u32 { + self.0.reduce_sum() + } + + #[inline(always)] + pub fn reduce_min(self) -> u32 { + self.0.reduce_min() + } + + #[inline(always)] + pub fn reduce_max(self) -> u32 { + self.0.reduce_max() + } + + // ── Lane-wise min/max ───────────────────────────────────────── + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + // ── Compare -> bitmask ──────────────────────────────────────── + + /// Per-lane equality. Returns an 8-bit bitmask (bit i = 1 iff lane i equal). + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u8 { + self.0.simd_eq(other.0).to_bitmask() as u8 + } + + /// Per-lane unsigned greater-than. Returns an 8-bit bitmask. + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u8 { + self.0.simd_gt(other.0).to_bitmask() as u8 + } +} + +impl Default for U32x8 { + #[inline(always)] + fn default() -> Self { + Self::splat(0) + } +} + +// ════════════════════════════════════════════════════════════════════ +// U32x16 — 16-lane u32 +// ════════════════════════════════════════════════════════════════════ + +/// 16-lane `u32` SIMD vector backed by `core::simd::u32x16`. +/// +/// Also used as the return type of `F32x16::to_bits`. 
+#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(transparent)] +pub struct U32x16(pub u32x16); + +impl U32x16 { + pub const LANES: usize = 16; + + #[inline(always)] + pub fn splat(v: u32) -> Self { + Self(u32x16::splat(v)) + } + + #[inline(always)] + pub fn from_slice(s: &[u32]) -> Self { + assert!(s.len() >= 16, "U32x16::from_slice needs >=16 elements"); + Self(u32x16::from_slice(s)) + } + + #[inline(always)] + pub fn from_array(arr: [u32; 16]) -> Self { + Self(u32x16::from_array(arr)) + } + + #[inline(always)] + pub fn to_array(self) -> [u32; 16] { + self.0.to_array() + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [u32]) { + assert!(s.len() >= 16, "U32x16::copy_to_slice needs >=16 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + #[inline(always)] + pub fn reduce_sum(self) -> u32 { + self.0.reduce_sum() + } + + #[inline(always)] + pub fn reduce_min(self) -> u32 { + self.0.reduce_min() + } + + #[inline(always)] + pub fn reduce_max(self) -> u32 { + self.0.reduce_max() + } + + // ── Lane-wise min/max ───────────────────────────────────────── + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + // ── Compare -> bitmask ──────────────────────────────────────── + + /// Per-lane equality. Returns a 16-bit bitmask (bit i = 1 iff lane i equal). + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u16 { + self.0.simd_eq(other.0).to_bitmask() as u16 + } + + /// Per-lane unsigned greater-than. Returns a 16-bit bitmask. + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u16 { + self.0.simd_gt(other.0).to_bitmask() as u16 + } +} + +impl Default for U32x16 { + #[inline(always)] + fn default() -> Self { + Self::splat(0) + } +} + +// ════════════════════════════════════════════════════════════════════ +// U16x32 — 32-lane u16 +// ════════════════════════════════════════════════════════════════════ + +/// 32-lane `u16` SIMD vector backed by `core::simd::u16x32`. +/// +/// API mirrors `simd_avx512::U16x32`. Miri-executable. +#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(transparent)] +pub struct U16x32(pub u16x32); + +impl U16x32 { + pub const LANES: usize = 32; + + #[inline(always)] + pub fn splat(v: u16) -> Self { + Self(u16x32::splat(v)) + } + + #[inline(always)] + pub fn from_slice(s: &[u16]) -> Self { + assert!(s.len() >= 32, "U16x32::from_slice needs >=32 elements"); + Self(u16x32::from_slice(s)) + } + + #[inline(always)] + pub fn from_array(arr: [u16; 32]) -> Self { + Self(u16x32::from_array(arr)) + } + + #[inline(always)] + pub fn to_array(self) -> [u16; 32] { + self.0.to_array() + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [u16]) { + assert!(s.len() >= 32, "U16x32::copy_to_slice needs >=32 elements"); + self.0.copy_to_slice(s); + } + + // ── Reductions ──────────────────────────────────────────────── + + /// Wrapping horizontal sum of all 32 lanes (result is u16, wraps on overflow). 
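+    /// For an exact sum of long accumulations, widen first, e.g.
+    /// `v.0.cast::<u32>().reduce_sum()` with `core::simd::num::SimdUint` in scope.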
+ #[inline(always)] + pub fn reduce_sum(self) -> u16 { + self.0.reduce_sum() + } + + #[inline(always)] + pub fn reduce_min(self) -> u16 { + self.0.reduce_min() + } + + #[inline(always)] + pub fn reduce_max(self) -> u16 { + self.0.reduce_max() + } + + // ── Lane-wise min/max ───────────────────────────────────────── + + #[inline(always)] + pub fn simd_min(self, other: Self) -> Self { + Self(self.0.simd_min(other.0)) + } + + #[inline(always)] + pub fn simd_max(self, other: Self) -> Self { + Self(self.0.simd_max(other.0)) + } + + // ── Saturating arithmetic ───────────────────────────────────── + + #[inline(always)] + pub fn saturating_add(self, other: Self) -> Self { + Self(self.0.saturating_add(other.0)) + } + + #[inline(always)] + pub fn saturating_sub(self, other: Self) -> Self { + Self(self.0.saturating_sub(other.0)) + } + + // ── Compare -> bitmask ──────────────────────────────────────── + + /// Per-lane equality. Returns a 32-bit bitmask (bit i = 1 iff lane i equal). + #[inline(always)] + pub fn cmpeq_mask(self, other: Self) -> u32 { + self.0.simd_eq(other.0).to_bitmask() as u32 + } + + /// Per-lane unsigned greater-than. Returns a 32-bit bitmask. + #[inline(always)] + pub fn cmpgt_mask(self, other: Self) -> u32 { + self.0.simd_gt(other.0).to_bitmask() as u32 + } +} + +impl Default for U16x32 { + #[inline(always)] + fn default() -> Self { + Self::splat(0) + } +} + +// ════════════════════════════════════════════════════════════════════ +// Tests +// ════════════════════════════════════════════════════════════════════ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn u16x32_splat_reduce() { + let v = U16x32::splat(3); + assert_eq!(v.reduce_min(), 3); + assert_eq!(v.reduce_max(), 3); + } + + #[test] + fn u16x32_cmpeq_mask_all_equal() { + let a = U16x32::splat(10); + let b = U16x32::splat(10); + assert_eq!(a.cmpeq_mask(b), u32::MAX); + } + + #[test] + fn u16x32_saturating_add_clamps() { + let a = U16x32::splat(60000u16); + let b = U16x32::splat(10000u16); + let c = a.saturating_add(b); + assert!(c.to_array().iter().all(|&v| v == u16::MAX)); + } + + #[test] + fn u32x16_splat_reduce() { + let v = U32x16::splat(7); + assert_eq!(v.reduce_sum(), 112u32); // 16 * 7 + assert_eq!(v.reduce_min(), 7); + assert_eq!(v.reduce_max(), 7); + } + + #[test] + fn u32x16_cmpgt_mask() { + let mut arr = [0u32; 16]; + for (i, x) in arr.iter_mut().enumerate() { + *x = i as u32; + } + let v = U32x16::from_array(arr); + let threshold = U32x16::splat(7); + // Lanes 8..15 are > 7 -> bits 8..15 set -> 0xFF00 + assert_eq!(v.cmpgt_mask(threshold), 0xFF00u16); + } + + #[test] + fn u32x8_splat_reduce() { + let v = U32x8::splat(5); + assert_eq!(v.reduce_sum(), 40u32); // 8 * 5 + assert_eq!(v.reduce_min(), 5); + assert_eq!(v.reduce_max(), 5); + } + + #[test] + fn u32x8_cmpeq_mask() { + let a = U32x8::from_array([1, 2, 1, 2, 1, 2, 1, 2]); + let b = U32x8::splat(1); + // Lanes 0,2,4,6 equal -> bits 0,2,4,6 -> 0b01010101 = 0x55 + assert_eq!(a.cmpeq_mask(b), 0x55u8); + } + + #[test] + fn u64x8_splat_reduce() { + let v = U64x8::splat(100); + assert_eq!(v.reduce_sum(), 800u64); // 8 * 100 + assert_eq!(v.reduce_min(), 100); + assert_eq!(v.reduce_max(), 100); + } + + #[test] + fn u64x8_cmpeq_mask_all() { + let a = U64x8::splat(42); + let b = U64x8::splat(42); + assert_eq!(a.cmpeq_mask(b), 0xFFu8); + } + + #[test] + fn u64x4_splat_reduce() { + let v = U64x4::splat(9); + assert_eq!(v.reduce_sum(), 36u64); // 4 * 9 + assert_eq!(v.reduce_min(), 9); + assert_eq!(v.reduce_max(), 9); + } + + #[test] + fn 
u64x4_cmpeq_mask_partial() { + let a = U64x4::from_array([1, 2, 1, 2]); + let b = U64x4::splat(1); + // Lanes 0,2 equal -> bits 0,2 -> 0b0101 = 5 + assert_eq!(a.cmpeq_mask(b), 0x05u8); + } + + #[test] + fn u64x4_cmpgt_mask() { + let a = U64x4::from_array([10, 1, 10, 1]); + let b = U64x4::splat(5); + // Lanes 0,2 > 5 -> bits 0,2 -> 0b0101 = 5 + assert_eq!(a.cmpgt_mask(b), 0x05u8); + } +}
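
**Usage sketch (not part of the diff):** a minimal example, assuming the
`nightly-simd` feature and the `U8x32` surface above, of how the
`cmpeq_mask` lane-to-bit mapping composes into a chunked byte search.
`find_byte` is a hypothetical helper, not in the crate.

```rust
#[cfg(feature = "nightly-simd")]
use ndarray::simd_nightly::U8x32;

/// Index of the first occurrence of `needle`, scanning 32 bytes per step.
/// `cmpeq_mask` sets bit i iff lane i matches, so `trailing_zeros` on the
/// first nonzero mask recovers the offset of the first hit.
#[cfg(feature = "nightly-simd")]
fn find_byte(haystack: &[u8], needle: u8) -> Option<usize> {
    let splat = U8x32::splat(needle);
    let mut i = 0;
    while i + 32 <= haystack.len() {
        let mask = U8x32::from_slice(&haystack[i..]).cmpeq_mask(splat);
        if mask != 0 {
            return Some(i + mask.trailing_zeros() as usize);
        }
        i += 32;
    }
    // Scalar tail for the final < 32 bytes.
    haystack[i..].iter().position(|&b| b == needle).map(|p| i + p)
}
```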