diff --git a/src/ggml-bitnet-mad.cpp b/src/ggml-bitnet-mad.cpp
index 4ba9d6509..7d16d0a31 100644
--- a/src/ggml-bitnet-mad.cpp
+++ b/src/ggml-bitnet-mad.cpp
@@ -11,7 +11,9 @@
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 #define QK_I2_S 128
 #elif defined(__ARM_NEON)
-#define QK_I2_S 64
+#define QK_I2_S 128 // <--- Replaced legacy 64 with 128
+#else
+#define QK_I2_S 128 // Fallback for environments without hardware acceleration
 #endif

 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
@@ -174,14 +176,18 @@ size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_
     // q8 -> 0, 1, 2
     //       |  |  |
     //      -1, 0, 1
-
+    // ====================================================================
+    // [Memory Packing Standardization]
+    // Aligned the packing stride to 32 so the layout is byte-for-byte
+    // identical to the AVX2 (x86) memory layout.
+    // The previous code packed with a 16-stride (assuming QK=64), which
+    // corrupted tensors during decoding.
+    // ====================================================================
     uint8_t* i2_weight = (uint8_t*)dst;
     for (int i = 0; i < n / QK_I2_S; i++) {
         for (int j = 0; j < QK_I2_S; j++) {
-            int group_idx = j / 16;
-            int group_pos = j % 16;
+            int group_idx = j / 32; // <--- Changed from 16 to 32 to sync with AVX2 layout
+            int group_pos = j % 32; // <--- Changed from 16 to 32
             uint8_t temp = (q8[i * QK_I2_S + j] << (6 - 2 * group_idx));
-            i2_weight[i * 16 + group_pos] |= temp;
+            i2_weight[i * 32 + group_pos] |= temp; // <--- Changed from 16 to 32
         }
     }

@@ -295,118 +301,122 @@ void ggml_vec_dot_i2_i8_s_1x1(int n, float * s, size_t bs, const void * vx, size
         s[row] = (float)sumi;
     }
 #elif defined(__ARM_NEON)
-    const uint8_t * x = (uint8_t *)vx;
-    const int8_t * y = (int8_t *)vy;
+    // ====================================================================
+    // [Path 2] Mobile Environment: ARM NEON / DotProd Acceleration
+    // ====================================================================
+    const uint8_t * x = (uint8_t *)vx;
+    const int8_t * y = (int8_t *)vy;

-    const int nb = n / QK_I2_S;
-    const int group32_num = nb / 32;
-    const int la_num = nb % 32;
-    const int groupla_num = nb % 32 != 0 ? 1 : 0;
+    // [Core Fix] GGUF files are typically packed on x86, which enforces QK=128.
+    // Removed the previous hardcoded loop unrolling (group32_num, la_num)
+    // that assumed QK=64 and caused memory offset corruption (the "word salad" bug).
+    // Refactored into a plain block-level loop (nb) to strictly match the 128 format.
+    const int QK = 128;
+    const int nb = n / QK;

-    const uint8x16_t mask = vdupq_n_u8(3);
+    const uint8x16_t mask = vdupq_n_u8(0x03);

-    // Process multiple columns; nrc is the number of columns to process
     for (int row = 0; row < nrc; row++) {
         int32x4_t accu = vdupq_n_s32(0);
-
-        // Compute the x pointer offset for the current row
-        const uint8_t * x_row = x + row * bx / 4;
-
-        for (int i=0; i < group32_num; i++) {
+        const uint8_t * x_row = x + (row * bx) / 4;
+
+        for (int b = 0; b < nb; b++) {
+            // With QK=128: one block of weights = 32 bytes, one block of activations (y) = 128 bytes
+            const uint8_t * px = x_row + b * 32;
+            const int8_t * py = y + b * QK;
+
+            // Split the 32-byte weights into two 16-byte chunks to fit NEON registers.
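+            // Layout note (from the packing code above): within each 32-byte
+            // block, byte k holds weights k, k+32, k+64 and k+96 in bits 7-6,
+            // 5-4, 3-2 and 1-0, so bit slot s of byte k pairs with y[k + s*32].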
+            for (int j = 0; j < 2; j++) {
+                int k = j * 16;
+                uint8x16_t xb = vld1q_u8(px + k);
+
+                // Extract the 2-bit values from MSB to LSB (matching the AVX2 unpacking logic)
+                int8x16_t v0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(xb, 6), mask));
+                int8x16_t v1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(xb, 4), mask));
+                int8x16_t v2 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(xb, 2), mask));
+                int8x16_t v3 = vreinterpretq_s8_u8(vandq_u8(xb, mask));
+
+                // Interleaved memory fetch with a 32-byte stride (matching the AVX2 layout)
+                int8x16_t y0 = vld1q_s8(py + k + 0*32);
+                int8x16_t y1 = vld1q_s8(py + k + 1*32);
+                int8x16_t y2 = vld1q_s8(py + k + 2*32);
+                int8x16_t y3 = vld1q_s8(py + k + 3*32);
 #if defined(__ARM_FEATURE_DOTPROD)
-
+                // Hardware acceleration (devices supporting DotProd)
+                accu = vdotq_s32(accu, v0, y0);
+                accu = vdotq_s32(accu, v1, y1);
+                accu = vdotq_s32(accu, v2, y2);
+                accu = vdotq_s32(accu, v3, y3);
 #else
-            int16x8_t accu32 = vdupq_n_s16(0);
-#endif
-            for (int j=0; j < 32; j++) {
-                uint8x16_t xq8_3 = vld1q_u8(x_row + i * 32 * 16 + j * 16);
-                uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
-                uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
-                uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
-
-                int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
-                int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
-                int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
-                int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
-
-                const int8x16_t yq8_0 = vld1q_s8(y + i * 32 * 64 + j * 64 + 0);
-                const int8x16_t yq8_1 = vld1q_s8(y + i * 32 * 64 + j * 64 + 16);
-                const int8x16_t yq8_2 = vld1q_s8(y + i * 32 * 64 + j * 64 + 32);
-                const int8x16_t yq8_3 = vld1q_s8(y + i * 32 * 64 + j * 64 + 48);
-
-#if defined(__ARM_FEATURE_DOTPROD)
-                accu = vdotq_s32(accu, q8_0, yq8_0);
-                accu = vdotq_s32(accu, q8_1, yq8_1);
-                accu = vdotq_s32(accu, q8_2, yq8_2);
-                accu = vdotq_s32(accu, q8_3, yq8_3);
-#else
-                accu32 = vmlal_s8(accu32, vget_low_s8(q8_0), vget_low_s8(yq8_0));
-                accu32 = vmlal_s8(accu32, vget_high_s8(q8_0), vget_high_s8(yq8_0));
-                accu32 = vmlal_s8(accu32, vget_low_s8(q8_1), vget_low_s8(yq8_1));
-                accu32 = vmlal_s8(accu32, vget_high_s8(q8_1), vget_high_s8(yq8_1));
-                accu32 = vmlal_s8(accu32, vget_low_s8(q8_2), vget_low_s8(yq8_2));
-                accu32 = vmlal_s8(accu32, vget_high_s8(q8_2), vget_high_s8(yq8_2));
-                accu32 = vmlal_s8(accu32, vget_low_s8(q8_3), vget_low_s8(yq8_3));
-                accu32 = vmlal_s8(accu32, vget_high_s8(q8_3), vget_high_s8(yq8_3));
+                // Widening multiply-accumulate fallback for devices without DotProd support
+                int16x8_t accula = vdupq_n_s16(0);
+                accula = vmlal_s8(accula, vget_low_s8(v0), vget_low_s8(y0));
+                accula = vmlal_s8(accula, vget_high_s8(v0), vget_high_s8(y0));
+                accula = vmlal_s8(accula, vget_low_s8(v1), vget_low_s8(y1));
+                accula = vmlal_s8(accula, vget_high_s8(v1), vget_high_s8(y1));
+                accula = vmlal_s8(accula, vget_low_s8(v2), vget_low_s8(y2));
+                accula = vmlal_s8(accula, vget_high_s8(v2), vget_high_s8(y2));
+                accula = vmlal_s8(accula, vget_low_s8(v3), vget_low_s8(y3));
+                accula = vmlal_s8(accula, vget_high_s8(v3), vget_high_s8(y3));
+
+                accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accula)));
+                accu = vaddq_s32(accu, vmovl_high_s16(accula));
 #endif
             }
-
-#if defined(__ARM_FEATURE_DOTPROD)
-
-#else
-            accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accu32)));
-            accu = vaddq_s32(accu, vmovl_high_s16(accu32));
-#endif
         }
-
-        for (int i = 0; i < groupla_num; i++){
-#if defined(__ARM_FEATURE_DOTPROD)
-
+        int64_t sumi = vaddlvq_s32(accu);
+        s[row] = (float)sumi;
+    }
 #else
-            int16x8_t accula = vdupq_n_s16(0);
-#endif
-            for (int j = 0; j < la_num; j++) {
-                uint8x16_t xq8_3 = vld1q_u8(x_row + group32_num * 32 * 16 + j * 16);
-                uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
-                uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
-                uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
+    // ====================================================================
+    // [Path 3] Pure C++ Scalar Fallback
+    // Environment: no hardware acceleration, or explicitly disabled (-U__ARM_NEON)
+    // ====================================================================
+    const uint8_t * x_ptr = (const uint8_t *)vx;
+    const int8_t * y_ptr = (const int8_t *)vy;
+
+    // [Core Fix] Strictly enforce QK=128 to match the x86 GGUF packing standard.
+    // This prevents the memory misalignment and out-of-bounds access that occur
+    // when falling back to a scalar path that wrongly assumes QK=64.
+    const int qk = 128;
+    const int nb = n / qk;

-                int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
-                int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
-                int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
-                int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
-
-                const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 0);
-                const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 16);
-                const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 32);
-                const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 48);
-
-#if defined(__ARM_FEATURE_DOTPROD)
-                accu = vdotq_s32(accu, q8_0, yq8_0);
-                accu = vdotq_s32(accu, q8_1, yq8_1);
-                accu = vdotq_s32(accu, q8_2, yq8_2);
-                accu = vdotq_s32(accu, q8_3, yq8_3);
-#else
-                accula = vmlal_s8(accula, vget_low_s8(q8_0), vget_low_s8(yq8_0));
-                accula = vmlal_s8(accula, vget_high_s8(q8_0), vget_high_s8(yq8_0));
-                accula = vmlal_s8(accula, vget_low_s8(q8_1), vget_low_s8(yq8_1));
-                accula = vmlal_s8(accula, vget_high_s8(q8_1), vget_high_s8(yq8_1));
-                accula = vmlal_s8(accula, vget_low_s8(q8_2), vget_low_s8(yq8_2));
-                accula = vmlal_s8(accula, vget_high_s8(q8_2), vget_high_s8(yq8_2));
-                accula = vmlal_s8(accula, vget_low_s8(q8_3), vget_low_s8(yq8_3));
-                accula = vmlal_s8(accula, vget_high_s8(q8_3), vget_high_s8(yq8_3));
-#endif
+    for (int row = 0; row < nrc; row++) {
+        // Use an int32_t accumulator to safely avoid 16-bit overflow
+        int32_t sumi = 0;
+        const uint8_t * x_row = x_ptr + row * (bx / 4);
+
+        for (int b = 0; b < nb; b++) {
+            const uint8_t * px = x_row + b * 32;  // 1 block of i2_s weights = 32 bytes
+            const int8_t * py = y_ptr + b * 128;  // 1 block of activations = 128 bytes
+
+            for (int k = 0; k < 32; k++) {
+                uint8_t xb = px[k];
+
+                // Unpack the 2-bit values from MSB to LSB.
+                // This extraction order matches the '_mm256_srli_epi16'
+                // logical shifts used in the AVX2 kernel.
+                int v0 = (xb >> 6) & 0x03; // bits 7-6
+                int v1 = (xb >> 4) & 0x03; // bits 5-4
+                int v2 = (xb >> 2) & 0x03; // bits 3-2
+                int v3 = xb & 0x03;        // bits 1-0
+
+                // [Crucial Math Alignment]
+                // 1. Directly multiply the extracted values (0, 1, 2) without applying
+                //    a (-1) offset. The zero-mean property of the activations allows
+                //    the offset correction to be handled implicitly in upper layers.
+                // 2. Fetch 'y' using a 32-stride interleaved layout to match the
+                //    AVX2 packing standard.
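+                // Worked example for k = 0: this accumulates
+                // w[0]*y[0] + w[32]*y[32] + w[64]*y[64] + w[96]*y[96],
+                // i.e. packed weight j always meets activation y[j].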
+                sumi += v0 * py[k + 0*32];
+                sumi += v1 * py[k + 1*32];
+                sumi += v2 * py[k + 2*32];
+                sumi += v3 * py[k + 3*32];
             }
-#if defined(__ARM_FEATURE_DOTPROD)
-
-#else
-            accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accula)));
-            accu = vaddq_s32(accu, vmovl_high_s16(accula));
-#endif
         }
-        int sumi = vaddlvq_s32(accu);
-        s[row] = (float)sumi;
+        // Do NOT apply the dequantization scale here.
+        // The scale is applied later in the ggml_mul_mat graph node.
+        s[row] = (float)sumi;
     }
 #endif
 }
@@ -632,156 +642,77 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size
     }
 }
 #elif defined(__ARM_NEON)
-    const uint8_t * x = (uint8_t *)vx;
-    const int8_t * y = (int8_t *)vy;
+    // ====================================================================
+    // [Mobile Environment: ARM NEON / DotProd] - 1xN Parallel Kernel
+    // Processes PARALLEL_SIZE rows of X against 1 common row/col of Y.
+    // ====================================================================
+    const uint8_t * x = (const uint8_t *)vx;
+    const int8_t * y = (const int8_t *)vy;

-    const int nb = n / QK_I2_S;
-    const int group32_num = nb / 32;
-    const int la_num = nb % 32;
-    const int groupla_num = nb % 32 != 0 ? 1 : 0;
-
-    const uint8x16_t mask = vdupq_n_u8(3);
+    // [Core Fix] Enforce QK=128 to match the standard GGUF memory layout.
+    // Replaces the legacy QK=64 hardcoded loop unrolling to prevent memory corruption.
+    const int QK = 128;
+    const int nb = n / QK;

-    // Process multiple rows; nrc is the number of rows to process
-    for (int row = 0; row < nrc; row += PARALLEL_SIZE) {
+    const uint8x16_t mask = vdupq_n_u8(0x03);
+    for (int row = 0; row < nrc; row += PARALLEL_SIZE) {
         int32x4_t accu[PARALLEL_SIZE];
         const uint8_t * x_row[PARALLEL_SIZE];
-
+
         for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
             accu[rb] = vdupq_n_s32(0);
             x_row[rb] = x + (row + rb) * bx / 4;
         }

-        for (int i = 0; i < group32_num; i++) {
-#if defined(__ARM_FEATURE_DOTPROD)
+        for (int b = 0; b < nb; b++) {
+            const int8_t * py = y + b * QK;

-#else
-            int16x8_t accu32[PARALLEL_SIZE];
-            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
-                accu32[rb] = vdupq_n_s16(0);
-            }
-#endif
-            const uint8_t * px[PARALLEL_SIZE];
-            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
-                px[rb] = x_row[rb] + i * 32 * 16;
-            }
+            for (int j = 0; j < 2; j++) {
+                int k = j * 16;

-            for (int j = 0; j < 32; j++) {
-                // Load y data (shared across all rows)
-                const int8x16_t yq8_0 = vld1q_s8(y + i * 32 * 64 + j * 64 + 0);
-                const int8x16_t yq8_1 = vld1q_s8(y + i * 32 * 64 + j * 64 + 16);
-                const int8x16_t yq8_2 = vld1q_s8(y + i * 32 * 64 + j * 64 + 32);
-                const int8x16_t yq8_3 = vld1q_s8(y + i * 32 * 64 + j * 64 + 48);
+                // Load Y data once (shared across all parallel X rows)
+                int8x16_t y0 = vld1q_s8(py + k + 0*32);
+                int8x16_t y1 = vld1q_s8(py + k + 1*32);
+                int8x16_t y2 = vld1q_s8(py + k + 2*32);
+                int8x16_t y3 = vld1q_s8(py + k + 3*32);

-                // Process each row
                 for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
-                    uint8x16_t xq8_3 = vld1q_u8(px[rb] + 0);
-                    uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
-                    uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
-                    uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
-
-                    int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
-                    int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
-                    int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
-                    int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
-
-#if defined(__ARM_FEATURE_DOTPROD)
-                    accu[rb] = vdotq_s32(accu[rb], q8_0, yq8_0);
-                    accu[rb] = vdotq_s32(accu[rb], q8_1, yq8_1);
-                    accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2);
-                    accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3);
-#else
-                    accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3));
-                    accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3));
-                    accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2));
-                    accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2));
-                    accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1));
-                    accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1));
-                    accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0));
-                    accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0));
-
-#endif
-                    px[rb] += 16;
-                }
-            }
-
-#if defined(__ARM_FEATURE_DOTPROD)
-
-#else
-            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
-                accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accu32[rb])));
-                accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accu32[rb]));
-            }
-#endif
-        }
+                    const uint8_t * px = x_row[rb] + b * 32;
+                    uint8x16_t xb = vld1q_u8(px + k);

-        for (int i = 0; i < groupla_num; i++) {
-#if defined(__ARM_FEATURE_DOTPROD)
-
-#else
-            int16x8_t accula[PARALLEL_SIZE];
-            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
-                accula[rb] = vdupq_n_s16(0);
-            }
-#endif
-            const uint8_t * px[PARALLEL_SIZE];
-            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
-                px[rb] = x_row[rb] + group32_num * 32 * 16;
-            }
+                    // Unpack the 2-bit values from MSB to LSB (matching the AVX2 layout)
+                    int8x16_t v0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(xb, 6), mask));
+                    int8x16_t v1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(xb, 4), mask));
+                    int8x16_t v2 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(xb, 2), mask));
+                    int8x16_t v3 = vreinterpretq_s8_u8(vandq_u8(xb, mask));

-            for (int j = 0; j < la_num; j++) {
-                // Load y data (shared across all rows)
-                const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 0);
-                const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 16);
-                const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 32);
-                const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 48);
-
-                // Process each row
-                for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
-                    uint8x16_t xq8_3 = vld1q_u8(px[rb] + 0);
-                    uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
-                    uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
-                    uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
-
-                    int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
-                    int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
-                    int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
-                    int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
-
 #if defined(__ARM_FEATURE_DOTPROD)
-                    accu[rb] = vdotq_s32(accu[rb], q8_0, yq8_0);
-                    accu[rb] = vdotq_s32(accu[rb], q8_1, yq8_1);
-                    accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2);
-                    accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3);
+                    accu[rb] = vdotq_s32(accu[rb], v0, y0);
+                    accu[rb] = vdotq_s32(accu[rb], v1, y1);
+                    accu[rb] = vdotq_s32(accu[rb], v2, y2);
+                    accu[rb] = vdotq_s32(accu[rb], v3, y3);
 #else
-                    accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3));
-                    accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3));
-                    accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2));
-                    accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2));
-                    accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1));
-                    accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1));
-                    accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0));
-                    accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0));
-
+                    int16x8_t accula = vdupq_n_s16(0);
+                    accula = vmlal_s8(accula, vget_low_s8(v0), vget_low_s8(y0));
+                    accula = vmlal_s8(accula, vget_high_s8(v0), vget_high_s8(y0));
+                    accula = vmlal_s8(accula, vget_low_s8(v1), vget_low_s8(y1));
+                    accula = vmlal_s8(accula, vget_high_s8(v1), vget_high_s8(y1));
+                    accula = vmlal_s8(accula, vget_low_s8(v2), vget_low_s8(y2));
+                    accula = vmlal_s8(accula, vget_high_s8(v2), vget_high_s8(y2));
+                    accula = vmlal_s8(accula, vget_low_s8(v3), vget_low_s8(y3));
+                    accula = vmlal_s8(accula, vget_high_s8(v3), vget_high_s8(y3));
+
+                    accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accula)));
+                    accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accula));
 #endif
-                    px[rb] += 16;
                 }
             }
-
-#if defined(__ARM_FEATURE_DOTPROD)
-
-#else
-            for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
-                accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accula[rb])));
-                accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accula[rb]));
-            }
-#endif
         }

-        // Reduce the results and write back
+        // Horizontal sum and write back
         for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
-            int sumi = vaddlvq_s32(accu[rb]);
+            int32_t sumi = vaddvq_s32(accu[rb]);
             s[row + rb] = (float)sumi;
         }
     }
@@ -886,153 +817,73 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size
     }
 }
 #elif defined(__ARM_NEON)
-    const uint8_t * x = (uint8_t *)vx;
-    const int8_t * y = (int8_t *)vy;
+    // ====================================================================
+    // [Mobile Environment: ARM NEON / DotProd] - Nx1 Parallel Kernel
+    // Processes 1 common row of X against PARALLEL_SIZE columns/rows of Y.
+    // ====================================================================
+    const uint8_t * x = (const uint8_t *)vx;
+    const int8_t * y = (const int8_t *)vy;

-    const int nb = n / QK_I2_S;
-    const int group32_num = nb / 32;
-    const int la_num = nb % 32;
-    const int groupla_num = nb % 32 != 0 ? 1 : 0;
+    // [Core Fix] Enforce QK=128 to match the standard GGUF memory layout.
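+    // As in the 1x1 kernel: each block is 32 bytes of packed weights and
+    // 128 bytes of int8 activations, so x advances 32 bytes and y 128 bytes per block.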
+    const int QK = 128;
+    const int nb = n / QK;

-    const uint8x16_t mask = vdupq_n_u8(3);
+    const uint8x16_t mask = vdupq_n_u8(0x03);

     for (int col = 0; col < nrc; col += PARALLEL_SIZE) {
         int32x4_t accu[PARALLEL_SIZE];
-
         for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
             accu[iy] = vdupq_n_s32(0);
         }

-        const int8_t * y_col = y + col * by;
-
-        for (int i = 0; i < group32_num; i++) {
-            const uint8_t *px = x + i * 512;       // i * 32 * 16
-            const int8_t *py = y_col + i * 2048;   // i * 32 * 64
+        for (int b = 0; b < nb; b++) {
+            // Load X data once (shared across all parallel Y columns)
+            const uint8_t * px = x + b * 32;

-#if defined(__ARM_FEATURE_DOTPROD)
+            for (int j = 0; j < 2; j++) {
+                int k = j * 16;
+                uint8x16_t xb = vld1q_u8(px + k);

-#else
-            int16x8_t accu32[PARALLEL_SIZE];
+                // Unpack the 2-bit values from MSB to LSB
+                int8x16_t v0 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(xb, 6), mask));
+                int8x16_t v1 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(xb, 4), mask));
+                int8x16_t v2 = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(xb, 2), mask));
+                int8x16_t v3 = vreinterpretq_s8_u8(vandq_u8(xb, mask));

-            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
-                accu32[iy] = vdupq_n_s16(0);
-            }
-#endif
-            for (int j = 0; j < 32; j++) {
-                // Load and unpack x data (shared across all columns)
-                uint8x16_t xq8_3 = vld1q_u8(px + 0);
-                uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
-                uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
-                uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
-
-                int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
-                int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
-                int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
-                int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
-
-                // Process each column
                 for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
-                    const int8x16_t yq8_0 = vld1q_s8(py + 0 * 16 + iy * by);
-                    const int8x16_t yq8_1 = vld1q_s8(py + 1 * 16 + iy * by);
-                    const int8x16_t yq8_2 = vld1q_s8(py + 2 * 16 + iy * by);
-                    const int8x16_t yq8_3 = vld1q_s8(py + 3 * 16 + iy * by);
+                    const int8_t * py = y + (col + iy) * by + b * QK;

-#if defined(__ARM_FEATURE_DOTPROD)
-                    accu[iy] = vdotq_s32(accu[iy], q8_0, yq8_0);
-                    accu[iy] = vdotq_s32(accu[iy], q8_1, yq8_1);
-                    accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2);
-                    accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3);
-#else
-                    accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0));
-                    accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0));
-                    accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1));
-                    accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1));
-                    accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2));
-                    accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2));
-                    accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3));
-                    accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3));
-#endif
-                }
-
-                px += 16;
-                py += 64;
-            }
-
-#if defined(__ARM_FEATURE_DOTPROD)
-
-#else
-            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
-                accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accu32[iy]), vmovl_s16(vget_low_s16(accu32[iy]))));
-            }
-#endif
-        }
-
-        for (int i = 0; i < groupla_num; i++) {
-            const uint8_t *px = x + group32_num * 512;
-            const int8_t *py = y_col + group32_num * 2048;
-
-#if defined(__ARM_FEATURE_DOTPROD)
-
-#else
-            int16x8_t accula[PARALLEL_SIZE];
-
-            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
-                accula[iy] = vdupq_n_s16(0);
-            }
-#endif
-
-            for (int j = 0; j < la_num; j++) {
-                // Load and unpack x data (shared across all columns)
-                uint8x16_t xq8_3 = vld1q_u8(px + 0);
-                uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2);
-                uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4);
-                uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6);
-
-                int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));
-                int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
-                int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
-                int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask));
-
-                // Process each column
-                for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
-                    const int8x16_t yq8_0 = vld1q_s8(py + 0 * 16 + iy * by);
-                    const int8x16_t yq8_1 = vld1q_s8(py + 1 * 16 + iy * by);
-                    const int8x16_t yq8_2 = vld1q_s8(py + 2 * 16 + iy * by);
-                    const int8x16_t yq8_3 = vld1q_s8(py + 3 * 16 + iy * by);
+                    int8x16_t y0 = vld1q_s8(py + k + 0*32);
+                    int8x16_t y1 = vld1q_s8(py + k + 1*32);
+                    int8x16_t y2 = vld1q_s8(py + k + 2*32);
+                    int8x16_t y3 = vld1q_s8(py + k + 3*32);

 #if defined(__ARM_FEATURE_DOTPROD)
-                    accu[iy] = vdotq_s32(accu[iy], q8_0, yq8_0);
-                    accu[iy] = vdotq_s32(accu[iy], q8_1, yq8_1);
-                    accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2);
-                    accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3);
+                    accu[iy] = vdotq_s32(accu[iy], v0, y0);
+                    accu[iy] = vdotq_s32(accu[iy], v1, y1);
+                    accu[iy] = vdotq_s32(accu[iy], v2, y2);
+                    accu[iy] = vdotq_s32(accu[iy], v3, y3);
 #else
-                    accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0));
-                    accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0));
-                    accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1));
-                    accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1));
-                    accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2));
-                    accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2));
-                    accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3));
-                    accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3));
+                    int16x8_t accula = vdupq_n_s16(0);
+                    accula = vmlal_s8(accula, vget_low_s8(v0), vget_low_s8(y0));
+                    accula = vmlal_s8(accula, vget_high_s8(v0), vget_high_s8(y0));
+                    accula = vmlal_s8(accula, vget_low_s8(v1), vget_low_s8(y1));
+                    accula = vmlal_s8(accula, vget_high_s8(v1), vget_high_s8(y1));
+                    accula = vmlal_s8(accula, vget_low_s8(v2), vget_low_s8(y2));
+                    accula = vmlal_s8(accula, vget_high_s8(v2), vget_high_s8(y2));
+                    accula = vmlal_s8(accula, vget_low_s8(v3), vget_low_s8(y3));
+                    accula = vmlal_s8(accula, vget_high_s8(v3), vget_high_s8(y3));
+
+                    accu[iy] = vaddq_s32(accu[iy], vmovl_s16(vget_low_s16(accula)));
+                    accu[iy] = vaddq_s32(accu[iy], vmovl_high_s16(accula));
 #endif
                 }
-
-                px += 16;
-                py += 64;
-            }
-
-#if defined(__ARM_FEATURE_DOTPROD)
-
-#else
-            for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
-                accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accula[iy]), vmovl_s16(vget_low_s16(accula[iy]))));
-            }
-#endif
             }
         }

-        // Reduce the results and write back
+        // Horizontal sum and write back
         for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
-            int sumi = vaddlvq_s32(accu[iy]);
+            int32_t sumi = vaddvq_s32(accu[iy]);
             s[(col + iy) * bs] = (float)sumi;
         }
     }
@@ -1041,6 +892,11 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size
 }

 void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) {
+#if defined(__AVX2__) || defined(__ARM_NEON)
+    // ====================================================================
+    // HW Acceleration Path (AVX2 & ARM NEON)
+    // Routes to the optimized parallel kernels when 'nrc' is a multiple of PARALLEL_SIZE.
+ // ==================================================================== if (nrc % PARALLEL_SIZE == 0) { #if defined(ACT_PARALLEL) @@ -1051,6 +907,14 @@ void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t b } else { + // Fallback to 1x1 processing for remainder rows/cols ggml_vec_dot_i2_i8_s_1x1(n, s, bs, vx, bx, vy, by, nrc); } -} \ No newline at end of file +#else + // ==================================================================== + // Pure Scalar Fallback Path + // Executed only when hardware acceleration is disabled or unsupported. + // ==================================================================== + ggml_vec_dot_i2_i8_s_1x1(n, s, bs, vx, bx, vy, by, nrc); +#endif +}