matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

commit 501a43674b8aa5675c059c08ff3584264aec4d05
parent 4615e6a5d18b3ee05831fab834628d51c011a6d9
Author: finwo <finwo@pm.me>
Date:   Sat, 18 Apr 2026 19:41:54 +0200

avx512-vnni optimizations

Diffstat:
M src/matmul.c | 114 ++++++++++++++++++++++++++++++++++++-------------------------------------------
1 file changed, 52 insertions(+), 62 deletions(-)

diff --git a/src/matmul.c b/src/matmul.c @@ -155,79 +155,69 @@ int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const } #ifdef __AVX512VNNI__ +static void pack_b_i8(size_t n, size_t p, const int8_t *B, int8_t *B_packed) { + size_t n4 = n / 4; + size_t p16 = p / 16; + for (size_t j16 = 0; j16 < p16; j16++) { + for (size_t k4 = 0; k4 < n4; k4++) { + int8_t *dst = &B_packed[(j16 * n4 + k4) * 64]; + for (size_t dj = 0; dj < 16; dj++) { + size_t j = j16 * 16 + dj; + for (size_t dk = 0; dk < 4; dk++) { + dst[dj * 4 + dk] = B[(k4 * 4 + dk) * p + j]; + } + } + } + } +} + int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { (void)scale; - const size_t ib = 64; - const size_t jb = 64; - const size_t kb = 32; -#pragma omp parallel for schedule(static) - for (size_t ii = 0; ii < m; ii += ib) { - size_t i_end = (ii + ib < m) ? ii + ib : m; - for (size_t jj = 0; jj < p; jj += jb) { - size_t j_end = (jj + jb < p) ? jj + jb : p; - size_t ti = i_end - ii; - size_t tj = j_end - jj; - int32_t acc[64 * 64]; - memset(acc, 0, ti * tj * sizeof(int32_t)); + size_t n4 = n / 4; + size_t p16 = p / 16; - for (size_t kk = 0; kk < n; kk += kb) { - size_t k_limit = (kk + kb < n) ? kk + kb : n; + int8_t *B_packed; + if (posix_memalign((void **)&B_packed, 64, p16 * n4 * 64) != 0) return -1; + pack_b_i8(n, p, B, B_packed); - for (size_t i = ii; i < i_end; i++) { - size_t li = i - ii; - int32_t *acc_row = &acc[li * tj]; - - for (size_t j = jj; j < j_end; j += 16) { - size_t j_chunk = (j + 16 <= j_end) ? 16 : (j_end - j); - __m512i result = _mm512_setzero_si512(); - - for (size_t k = kk; k < k_limit; k += 4) { - size_t k_chunk = (k + 4 <= k_limit) ? 
4 : (k_limit - k); - - uint32_t a4 = 0; - for (size_t dk = 0; dk < k_chunk; dk++) { - a4 |= (uint32_t)A[i * n + k + dk] << (dk * 8); - } - __m512i a_val = _mm512_set1_epi32(a4); - - int8_t b_buf[64] = {0}; - for (size_t dj = 0; dj < j_chunk; dj++) { - for (size_t dk = 0; dk < k_chunk; dk++) { - b_buf[dj * 4 + dk] = B[(k + dk) * p + j + dj]; - } - } - __m512i b_val = _mm512_load_si512((__m512i const *)b_buf); - - result = _mm512_dpbusd_epi32(result, a_val, b_val); - } - - int32_t tmp[16] __attribute__((aligned(64))); - _mm512_store_si512(tmp, result); + const uint32_t *A32 = (const uint32_t *)A; - size_t j_offset = j - jj; - for (size_t c = 0; c < j_chunk; c++) { - acc_row[j_offset + c] += tmp[c]; - } - } - } +#pragma omp parallel for schedule(static) + for (size_t i = 0; i < m; i++) { + for (size_t j16 = 0; j16 < p16; j16++) { + __m512i result = _mm512_setzero_si512(); + for (size_t k4 = 0; k4 < n4; k4++) { + __m512i a_val = _mm512_set1_epi32(A32[i * n4 + k4]); + __m512i b_val = _mm512_load_si512((__m512i const *)&B_packed[(j16 * n4 + k4) * 64]); + result = _mm512_dpbusd_epi32(result, a_val, b_val); } - - for (size_t i = ii; i < i_end; i++) { - size_t li = i - ii; - for (size_t j = jj; j < j_end; j++) { - size_t lj = j - jj; - int32_t v = acc[li * tj + lj]; - if (v > 255) - v = 255; - else if (v < 0) - v = 0; - C[i * p + j] = (uint8_t)v; - } + int32_t tmp[16] __attribute__((aligned(64))); + _mm512_store_si512(tmp, result); + for (size_t dj = 0; dj < 16; dj++) { + int32_t v = tmp[dj]; + if (v > 255) + v = 255; + else if (v < 0) + v = 0; + C[i * p + j16 * 16 + dj] = (uint8_t)v; + } + } + for (size_t j = p16 * 16; j < p; j++) { + int32_t sum = 0; + for (size_t k = 0; k < n; k++) { + sum += (int)A[i * n + k] * (int)B[k * p + j]; } + if (sum > 255) + sum = 255; + else if (sum < 0) + sum = 0; + C[i * p + j] = (uint8_t)sum; } } + + free(B_packed); return 0; } #endif