commit 501a43674b8aa5675c059c08ff3584264aec4d05
parent 4615e6a5d18b3ee05831fab834628d51c011a6d9
Author: finwo <finwo@pm.me>
Date: Sat, 18 Apr 2026 19:41:54 +0200
avx512-vnni optimizations
Diffstat:
| M | src/matmul.c | | | 114 | ++++++++++++++++++++++++++++++++++++------------------------------------------- |
1 file changed, 52 insertions(+), 62 deletions(-)
diff --git a/src/matmul.c b/src/matmul.c
@@ -155,79 +155,69 @@ int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const
}
#ifdef __AVX512VNNI__
+static void pack_b_i8(size_t n, size_t p, const int8_t *B, int8_t *B_packed) {
+ size_t n4 = n / 4;
+ size_t p16 = p / 16;
+ for (size_t j16 = 0; j16 < p16; j16++) {
+ for (size_t k4 = 0; k4 < n4; k4++) {
+ int8_t *dst = &B_packed[(j16 * n4 + k4) * 64];
+ for (size_t dj = 0; dj < 16; dj++) {
+ size_t j = j16 * 16 + dj;
+ for (size_t dk = 0; dk < 4; dk++) {
+ dst[dj * 4 + dk] = B[(k4 * 4 + dk) * p + j];
+ }
+ }
+ }
+ }
+}
+
int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C,
double scale) {
(void)scale;
- const size_t ib = 64;
- const size_t jb = 64;
- const size_t kb = 32;
-#pragma omp parallel for schedule(static)
- for (size_t ii = 0; ii < m; ii += ib) {
- size_t i_end = (ii + ib < m) ? ii + ib : m;
- for (size_t jj = 0; jj < p; jj += jb) {
- size_t j_end = (jj + jb < p) ? jj + jb : p;
- size_t ti = i_end - ii;
- size_t tj = j_end - jj;
- int32_t acc[64 * 64];
- memset(acc, 0, ti * tj * sizeof(int32_t));
+ size_t n4 = n / 4;
+ size_t p16 = p / 16;
- for (size_t kk = 0; kk < n; kk += kb) {
- size_t k_limit = (kk + kb < n) ? kk + kb : n;
+ int8_t *B_packed;
+ if (posix_memalign((void **)&B_packed, 64, p16 * n4 * 64) != 0) return -1;
+ pack_b_i8(n, p, B, B_packed);
- for (size_t i = ii; i < i_end; i++) {
- size_t li = i - ii;
- int32_t *acc_row = &acc[li * tj];
-
- for (size_t j = jj; j < j_end; j += 16) {
- size_t j_chunk = (j + 16 <= j_end) ? 16 : (j_end - j);
- __m512i result = _mm512_setzero_si512();
-
- for (size_t k = kk; k < k_limit; k += 4) {
- size_t k_chunk = (k + 4 <= k_limit) ? 4 : (k_limit - k);
-
- uint32_t a4 = 0;
- for (size_t dk = 0; dk < k_chunk; dk++) {
- a4 |= (uint32_t)A[i * n + k + dk] << (dk * 8);
- }
- __m512i a_val = _mm512_set1_epi32(a4);
-
- int8_t b_buf[64] = {0};
- for (size_t dj = 0; dj < j_chunk; dj++) {
- for (size_t dk = 0; dk < k_chunk; dk++) {
- b_buf[dj * 4 + dk] = B[(k + dk) * p + j + dj];
- }
- }
- __m512i b_val = _mm512_load_si512((__m512i const *)b_buf);
-
- result = _mm512_dpbusd_epi32(result, a_val, b_val);
- }
-
- int32_t tmp[16] __attribute__((aligned(64)));
- _mm512_store_si512(tmp, result);
+ const uint32_t *A32 = (const uint32_t *)A;
- size_t j_offset = j - jj;
- for (size_t c = 0; c < j_chunk; c++) {
- acc_row[j_offset + c] += tmp[c];
- }
- }
- }
+#pragma omp parallel for schedule(static)
+ for (size_t i = 0; i < m; i++) {
+ for (size_t j16 = 0; j16 < p16; j16++) {
+ __m512i result = _mm512_setzero_si512();
+ for (size_t k4 = 0; k4 < n4; k4++) {
+ __m512i a_val = _mm512_set1_epi32(A32[i * n4 + k4]);
+ __m512i b_val = _mm512_load_si512((__m512i const *)&B_packed[(j16 * n4 + k4) * 64]);
+ result = _mm512_dpbusd_epi32(result, a_val, b_val);
}
-
- for (size_t i = ii; i < i_end; i++) {
- size_t li = i - ii;
- for (size_t j = jj; j < j_end; j++) {
- size_t lj = j - jj;
- int32_t v = acc[li * tj + lj];
- if (v > 255)
- v = 255;
- else if (v < 0)
- v = 0;
- C[i * p + j] = (uint8_t)v;
- }
+ int32_t tmp[16] __attribute__((aligned(64)));
+ _mm512_store_si512(tmp, result);
+ for (size_t dj = 0; dj < 16; dj++) {
+ int32_t v = tmp[dj];
+ if (v > 255)
+ v = 255;
+ else if (v < 0)
+ v = 0;
+ C[i * p + j16 * 16 + dj] = (uint8_t)v;
+ }
+ }
+ for (size_t j = p16 * 16; j < p; j++) {
+ int32_t sum = 0;
+ for (size_t k = 0; k < n; k++) {
+ sum += (int)A[i * n + k] * (int)B[k * p + j];
}
+ if (sum > 255)
+ sum = 255;
+ else if (sum < 0)
+ sum = 0;
+ C[i * p + j] = (uint8_t)sum;
}
}
+
+ free(B_packed);
return 0;
}
#endif