matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

commit 560fa4c4e737269b1fc591e34d483f3385c56934
parent afc408a3afb47a47d58034b8f0ce806b5ad47c0a
Author: finwo <finwo@pm.me>
Date:   Mon, 20 Apr 2026 02:15:47 +0200

Unroll by 4 for more performance

Diffstat:
Msrc/matmul.c | 172++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------------
1 file changed, 136 insertions(+), 36 deletions(-)

diff --git a/src/matmul.c b/src/matmul.c @@ -331,22 +331,47 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const memset(acc, 0, ti * tj * sizeof(float)); for (size_t kk = 0; kk < n; kk += kb) { - size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end8 = kk + (k_end - kk) / 8 * 8; for (size_t i = ii; i < i_end; i++) { size_t li = i - ii; for (size_t j = jj; j + 8 <= j_end; j += 8) { - size_t lj = j - jj; - size_t j8idx = j / 8; - __m256 acc_vec = _mm256_setzero_ps(); - for (size_t k = kk; k < k_end; k++) { - size_t k8 = k / 8; - size_t dk = k % 8; + size_t lj = j - jj; + size_t j8idx = j / 8; + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + size_t k = kk; + for (; k + 4 <= k_end8; k += 4) { + size_t k8_0 = k / 8, dk_0 = k % 8; + size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8; + size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8; + size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8; + __m256 a0 = _mm256_set1_ps(A[i * n + k]); + __m256 a1 = _mm256_set1_ps(A[i * n + k + 1]); + __m256 a2 = _mm256_set1_ps(A[i * n + k + 2]); + __m256 a3 = _mm256_set1_ps(A[i * n + k + 3]); + __m256 b0 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]); + __m256 b1 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]); + __m256 b2 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]); + __m256 b3 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]); + acc0 = _mm256_fmadd_ps(a0, b0, acc0); + acc1 = _mm256_fmadd_ps(a1, b1, acc1); + acc2 = _mm256_fmadd_ps(a2, b2, acc2); + acc3 = _mm256_fmadd_ps(a3, b3, acc3); + } + acc0 = _mm256_add_ps(acc0, acc1); + acc2 = _mm256_add_ps(acc2, acc3); + acc0 = _mm256_add_ps(acc0, acc2); + for (; k < k_end; k++) { + size_t k8 = k / 8, dk = k % 8; __m256 a_bcast = _mm256_set1_ps(A[i * n + k]); __m256 b_val = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]); - acc_vec = _mm256_fmadd_ps(a_bcast, b_val, acc_vec); + acc0 = _mm256_fmadd_ps(a_bcast, b_val, acc0); } float tmp[8] __attribute__((aligned(32))); - _mm256_store_ps(tmp, acc_vec); + _mm256_store_ps(tmp, acc0); for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj]; } for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { @@ -427,22 +452,47 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons memset(acc, 0, ti * tj * sizeof(float)); for (size_t kk = 0; kk < n; kk += kb) { - size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end16 = kk + (k_end - kk) / 16 * 16; for (size_t i = ii; i < i_end; i++) { size_t li = i - ii; for (size_t j = jj; j + 16 <= j_end; j += 16) { - size_t lj = j - jj; - size_t j16idx = j / 16; - __m512 acc_vec = _mm512_setzero_ps(); - for (size_t k = kk; k < k_end; k++) { - size_t k16 = k / 16; - size_t dk = k % 16; + size_t lj = j - jj; + size_t j16idx = j / 16; + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + size_t k = kk; + for (; k + 4 <= k_end16; k += 4) { + size_t k16_0 = k / 16, dk_0 = k % 16; + size_t k16_1 = (k + 1) / 16, dk_1 = (k + 1) % 16; + size_t k16_2 = (k + 2) / 16, dk_2 = (k + 2) % 16; + size_t k16_3 = (k + 3) / 16, dk_3 = (k + 3) % 16; + __m512 a0 = _mm512_set1_ps(A[i * n + k]); + __m512 a1 = _mm512_set1_ps(A[i * n + k + 1]); + __m512 a2 = _mm512_set1_ps(A[i * n + k + 2]); + __m512 a3 = _mm512_set1_ps(A[i * n + k + 3]); + __m512 b0 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_0) * 256 + dk_0 * 16]); + __m512 b1 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_1) * 256 + dk_1 * 16]); + __m512 b2 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_2) * 256 + dk_2 * 16]); + __m512 b3 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_3) * 256 + dk_3 * 16]); + acc0 = _mm512_fmadd_ps(a0, b0, acc0); + acc1 = _mm512_fmadd_ps(a1, b1, acc1); + acc2 = _mm512_fmadd_ps(a2, b2, acc2); + acc3 = _mm512_fmadd_ps(a3, b3, acc3); + } + acc0 = _mm512_add_ps(acc0, acc1); + acc2 = _mm512_add_ps(acc2, acc3); + acc0 = _mm512_add_ps(acc0, acc2); + for (; k < k_end; k++) { + size_t k16 = k / 16, dk = k % 16; __m512 a_bcast = _mm512_set1_ps(A[i * n + k]); __m512 b_val = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]); - acc_vec = _mm512_fmadd_ps(a_bcast, b_val, acc_vec); + acc0 = _mm512_fmadd_ps(a_bcast, b_val, acc0); } float tmp[16] __attribute__((aligned(64))); - _mm512_store_ps(tmp, acc_vec); + _mm512_store_ps(tmp, acc0); for (size_t dj = 0; dj < 16; dj++) acc[li * tj + lj + dj] += tmp[dj]; } for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) { @@ -593,22 +643,47 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const memset(acc, 0, ti * tj * sizeof(double)); for (size_t kk = 0; kk < n; kk += kb) { - size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end4 = kk + (k_end - kk) / 4 * 4; for (size_t i = ii; i < i_end; i++) { size_t li = i - ii; for (size_t j = jj; j + 4 <= j_end; j += 4) { - size_t lj = j - jj; - size_t j4idx = j / 4; - __m256d acc_vec = _mm256_setzero_pd(); - for (size_t k = kk; k < k_end; k++) { - size_t k4 = k / 4; - size_t dk = k % 4; + size_t lj = j - jj; + size_t j4idx = j / 4; + __m256d acc0 = _mm256_setzero_pd(); + __m256d acc1 = _mm256_setzero_pd(); + __m256d acc2 = _mm256_setzero_pd(); + __m256d acc3 = _mm256_setzero_pd(); + size_t k = kk; + for (; k + 4 <= k_end4; k += 4) { + size_t k4_0 = k / 4, dk_0 = k % 4; + size_t k4_1 = (k + 1) / 4, dk_1 = (k + 1) % 4; + size_t k4_2 = (k + 2) / 4, dk_2 = (k + 2) % 4; + size_t k4_3 = (k + 3) / 4, dk_3 = (k + 3) % 4; + __m256d a0 = _mm256_set1_pd(A[i * n + k]); + __m256d a1 = _mm256_set1_pd(A[i * n + k + 1]); + __m256d a2 = _mm256_set1_pd(A[i * n + k + 2]); + __m256d a3 = _mm256_set1_pd(A[i * n + k + 3]); + __m256d b0 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_0) * 16 + dk_0 * 4]); + __m256d b1 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_1) * 16 + dk_1 * 4]); + __m256d b2 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_2) * 16 + dk_2 * 4]); + __m256d b3 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_3) * 16 + dk_3 * 4]); + acc0 = _mm256_fmadd_pd(a0, b0, acc0); + acc1 = _mm256_fmadd_pd(a1, b1, acc1); + acc2 = _mm256_fmadd_pd(a2, b2, acc2); + acc3 = _mm256_fmadd_pd(a3, b3, acc3); + } + acc0 = _mm256_add_pd(acc0, acc1); + acc2 = _mm256_add_pd(acc2, acc3); + acc0 = _mm256_add_pd(acc0, acc2); + for (; k < k_end; k++) { + size_t k4 = k / 4, dk = k % 4; __m256d a_bcast = _mm256_set1_pd(A[i * n + k]); __m256d b_val = _mm256_load_pd(&B_packed[(j4idx * n4 + k4) * 16 + dk * 4]); - acc_vec = _mm256_fmadd_pd(a_bcast, b_val, acc_vec); + acc0 = _mm256_fmadd_pd(a_bcast, b_val, acc0); } double tmp[4] __attribute__((aligned(32))); - _mm256_store_pd(tmp, acc_vec); + _mm256_store_pd(tmp, acc0); for (size_t dj = 0; dj < 4; dj++) acc[li * tj + lj + dj] += tmp[dj]; } for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) { @@ -689,22 +764,47 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con memset(acc, 0, ti * tj * sizeof(double)); for (size_t kk = 0; kk < n; kk += kb) { - size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end8 = kk + (k_end - kk) / 8 * 8; for (size_t i = ii; i < i_end; i++) { size_t li = i - ii; for (size_t j = jj; j + 8 <= j_end; j += 8) { - size_t lj = j - jj; - size_t j8idx = j / 8; - __m512d acc_vec = _mm512_setzero_pd(); - for (size_t k = kk; k < k_end; k++) { - size_t k8 = k / 8; - size_t dk = k % 8; + size_t lj = j - jj; + size_t j8idx = j / 8; + __m512d acc0 = _mm512_setzero_pd(); + __m512d acc1 = _mm512_setzero_pd(); + __m512d acc2 = _mm512_setzero_pd(); + __m512d acc3 = _mm512_setzero_pd(); + size_t k = kk; + for (; k + 4 <= k_end8; k += 4) { + size_t k8_0 = k / 8, dk_0 = k % 8; + size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8; + size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8; + size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8; + __m512d a0 = _mm512_set1_pd(A[i * n + k]); + __m512d a1 = _mm512_set1_pd(A[i * n + k + 1]); + __m512d a2 = _mm512_set1_pd(A[i * n + k + 2]); + __m512d a3 = _mm512_set1_pd(A[i * n + k + 3]); + __m512d b0 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]); + __m512d b1 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]); + __m512d b2 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]); + __m512d b3 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]); + acc0 = _mm512_fmadd_pd(a0, b0, acc0); + acc1 = _mm512_fmadd_pd(a1, b1, acc1); + acc2 = _mm512_fmadd_pd(a2, b2, acc2); + acc3 = _mm512_fmadd_pd(a3, b3, acc3); + } + acc0 = _mm512_add_pd(acc0, acc1); + acc2 = _mm512_add_pd(acc2, acc3); + acc0 = _mm512_add_pd(acc0, acc2); + for (; k < k_end; k++) { + size_t k8 = k / 8, dk = k % 8; __m512d a_bcast = _mm512_set1_pd(A[i * n + k]); __m512d b_val = _mm512_load_pd(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]); - acc_vec = _mm512_fmadd_pd(a_bcast, b_val, acc_vec); + acc0 = _mm512_fmadd_pd(a_bcast, b_val, acc0); } double tmp[8] __attribute__((aligned(64))); - _mm512_store_pd(tmp, acc_vec); + _mm512_store_pd(tmp, acc0); for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj]; } for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {