commit 560fa4c4e737269b1fc591e34d483f3385c56934
parent afc408a3afb47a47d58034b8f0ce806b5ad47c0a
Author: finwo <finwo@pm.me>
Date: Mon, 20 Apr 2026 02:15:47 +0200
Unroll by 4 for more performance
Diffstat:
| M | src/matmul.c | | | 172 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------- |
1 file changed, 136 insertions(+), 36 deletions(-)
diff --git a/src/matmul.c b/src/matmul.c
@@ -331,22 +331,47 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const
memset(acc, 0, ti * tj * sizeof(float));
for (size_t kk = 0; kk < n; kk += kb) {
- size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end8 = kk + (k_end - kk) / 8 * 8;
for (size_t i = ii; i < i_end; i++) {
size_t li = i - ii;
for (size_t j = jj; j + 8 <= j_end; j += 8) {
- size_t lj = j - jj;
- size_t j8idx = j / 8;
- __m256 acc_vec = _mm256_setzero_ps();
- for (size_t k = kk; k < k_end; k++) {
- size_t k8 = k / 8;
- size_t dk = k % 8;
+ size_t lj = j - jj;
+ size_t j8idx = j / 8;
+ __m256 acc0 = _mm256_setzero_ps();
+ __m256 acc1 = _mm256_setzero_ps();
+ __m256 acc2 = _mm256_setzero_ps();
+ __m256 acc3 = _mm256_setzero_ps();
+ size_t k = kk;
+ for (; k + 4 <= k_end8; k += 4) {
+ size_t k8_0 = k / 8, dk_0 = k % 8;
+ size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8;
+ size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8;
+ size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8;
+ __m256 a0 = _mm256_set1_ps(A[i * n + k]);
+ __m256 a1 = _mm256_set1_ps(A[i * n + k + 1]);
+ __m256 a2 = _mm256_set1_ps(A[i * n + k + 2]);
+ __m256 a3 = _mm256_set1_ps(A[i * n + k + 3]);
+ __m256 b0 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]);
+ __m256 b1 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]);
+ __m256 b2 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]);
+ __m256 b3 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]);
+ acc0 = _mm256_fmadd_ps(a0, b0, acc0);
+ acc1 = _mm256_fmadd_ps(a1, b1, acc1);
+ acc2 = _mm256_fmadd_ps(a2, b2, acc2);
+ acc3 = _mm256_fmadd_ps(a3, b3, acc3);
+ }
+ acc0 = _mm256_add_ps(acc0, acc1);
+ acc2 = _mm256_add_ps(acc2, acc3);
+ acc0 = _mm256_add_ps(acc0, acc2);
+ for (; k < k_end; k++) {
+ size_t k8 = k / 8, dk = k % 8;
__m256 a_bcast = _mm256_set1_ps(A[i * n + k]);
__m256 b_val = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
- acc_vec = _mm256_fmadd_ps(a_bcast, b_val, acc_vec);
+ acc0 = _mm256_fmadd_ps(a_bcast, b_val, acc0);
}
float tmp[8] __attribute__((aligned(32)));
- _mm256_store_ps(tmp, acc_vec);
+ _mm256_store_ps(tmp, acc0);
for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj];
}
for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
@@ -427,22 +452,47 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons
memset(acc, 0, ti * tj * sizeof(float));
for (size_t kk = 0; kk < n; kk += kb) {
- size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end16 = kk + (k_end - kk) / 16 * 16;
for (size_t i = ii; i < i_end; i++) {
size_t li = i - ii;
for (size_t j = jj; j + 16 <= j_end; j += 16) {
- size_t lj = j - jj;
- size_t j16idx = j / 16;
- __m512 acc_vec = _mm512_setzero_ps();
- for (size_t k = kk; k < k_end; k++) {
- size_t k16 = k / 16;
- size_t dk = k % 16;
+ size_t lj = j - jj;
+ size_t j16idx = j / 16;
+ __m512 acc0 = _mm512_setzero_ps();
+ __m512 acc1 = _mm512_setzero_ps();
+ __m512 acc2 = _mm512_setzero_ps();
+ __m512 acc3 = _mm512_setzero_ps();
+ size_t k = kk;
+ for (; k + 4 <= k_end16; k += 4) {
+ size_t k16_0 = k / 16, dk_0 = k % 16;
+ size_t k16_1 = (k + 1) / 16, dk_1 = (k + 1) % 16;
+ size_t k16_2 = (k + 2) / 16, dk_2 = (k + 2) % 16;
+ size_t k16_3 = (k + 3) / 16, dk_3 = (k + 3) % 16;
+ __m512 a0 = _mm512_set1_ps(A[i * n + k]);
+ __m512 a1 = _mm512_set1_ps(A[i * n + k + 1]);
+ __m512 a2 = _mm512_set1_ps(A[i * n + k + 2]);
+ __m512 a3 = _mm512_set1_ps(A[i * n + k + 3]);
+ __m512 b0 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_0) * 256 + dk_0 * 16]);
+ __m512 b1 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_1) * 256 + dk_1 * 16]);
+ __m512 b2 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_2) * 256 + dk_2 * 16]);
+ __m512 b3 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_3) * 256 + dk_3 * 16]);
+ acc0 = _mm512_fmadd_ps(a0, b0, acc0);
+ acc1 = _mm512_fmadd_ps(a1, b1, acc1);
+ acc2 = _mm512_fmadd_ps(a2, b2, acc2);
+ acc3 = _mm512_fmadd_ps(a3, b3, acc3);
+ }
+ acc0 = _mm512_add_ps(acc0, acc1);
+ acc2 = _mm512_add_ps(acc2, acc3);
+ acc0 = _mm512_add_ps(acc0, acc2);
+ for (; k < k_end; k++) {
+ size_t k16 = k / 16, dk = k % 16;
__m512 a_bcast = _mm512_set1_ps(A[i * n + k]);
__m512 b_val = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]);
- acc_vec = _mm512_fmadd_ps(a_bcast, b_val, acc_vec);
+ acc0 = _mm512_fmadd_ps(a_bcast, b_val, acc0);
}
float tmp[16] __attribute__((aligned(64)));
- _mm512_store_ps(tmp, acc_vec);
+ _mm512_store_ps(tmp, acc0);
for (size_t dj = 0; dj < 16; dj++) acc[li * tj + lj + dj] += tmp[dj];
}
for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) {
@@ -593,22 +643,47 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const
memset(acc, 0, ti * tj * sizeof(double));
for (size_t kk = 0; kk < n; kk += kb) {
- size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end4 = kk + (k_end - kk) / 4 * 4;
for (size_t i = ii; i < i_end; i++) {
size_t li = i - ii;
for (size_t j = jj; j + 4 <= j_end; j += 4) {
- size_t lj = j - jj;
- size_t j4idx = j / 4;
- __m256d acc_vec = _mm256_setzero_pd();
- for (size_t k = kk; k < k_end; k++) {
- size_t k4 = k / 4;
- size_t dk = k % 4;
+ size_t lj = j - jj;
+ size_t j4idx = j / 4;
+ __m256d acc0 = _mm256_setzero_pd();
+ __m256d acc1 = _mm256_setzero_pd();
+ __m256d acc2 = _mm256_setzero_pd();
+ __m256d acc3 = _mm256_setzero_pd();
+ size_t k = kk;
+ for (; k + 4 <= k_end4; k += 4) {
+ size_t k4_0 = k / 4, dk_0 = k % 4;
+ size_t k4_1 = (k + 1) / 4, dk_1 = (k + 1) % 4;
+ size_t k4_2 = (k + 2) / 4, dk_2 = (k + 2) % 4;
+ size_t k4_3 = (k + 3) / 4, dk_3 = (k + 3) % 4;
+ __m256d a0 = _mm256_set1_pd(A[i * n + k]);
+ __m256d a1 = _mm256_set1_pd(A[i * n + k + 1]);
+ __m256d a2 = _mm256_set1_pd(A[i * n + k + 2]);
+ __m256d a3 = _mm256_set1_pd(A[i * n + k + 3]);
+ __m256d b0 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_0) * 16 + dk_0 * 4]);
+ __m256d b1 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_1) * 16 + dk_1 * 4]);
+ __m256d b2 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_2) * 16 + dk_2 * 4]);
+ __m256d b3 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_3) * 16 + dk_3 * 4]);
+ acc0 = _mm256_fmadd_pd(a0, b0, acc0);
+ acc1 = _mm256_fmadd_pd(a1, b1, acc1);
+ acc2 = _mm256_fmadd_pd(a2, b2, acc2);
+ acc3 = _mm256_fmadd_pd(a3, b3, acc3);
+ }
+ acc0 = _mm256_add_pd(acc0, acc1);
+ acc2 = _mm256_add_pd(acc2, acc3);
+ acc0 = _mm256_add_pd(acc0, acc2);
+ for (; k < k_end; k++) {
+ size_t k4 = k / 4, dk = k % 4;
__m256d a_bcast = _mm256_set1_pd(A[i * n + k]);
__m256d b_val = _mm256_load_pd(&B_packed[(j4idx * n4 + k4) * 16 + dk * 4]);
- acc_vec = _mm256_fmadd_pd(a_bcast, b_val, acc_vec);
+ acc0 = _mm256_fmadd_pd(a_bcast, b_val, acc0);
}
double tmp[4] __attribute__((aligned(32)));
- _mm256_store_pd(tmp, acc_vec);
+ _mm256_store_pd(tmp, acc0);
for (size_t dj = 0; dj < 4; dj++) acc[li * tj + lj + dj] += tmp[dj];
}
for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) {
@@ -689,22 +764,47 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con
memset(acc, 0, ti * tj * sizeof(double));
for (size_t kk = 0; kk < n; kk += kb) {
- size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end8 = kk + (k_end - kk) / 8 * 8;
for (size_t i = ii; i < i_end; i++) {
size_t li = i - ii;
for (size_t j = jj; j + 8 <= j_end; j += 8) {
- size_t lj = j - jj;
- size_t j8idx = j / 8;
- __m512d acc_vec = _mm512_setzero_pd();
- for (size_t k = kk; k < k_end; k++) {
- size_t k8 = k / 8;
- size_t dk = k % 8;
+ size_t lj = j - jj;
+ size_t j8idx = j / 8;
+ __m512d acc0 = _mm512_setzero_pd();
+ __m512d acc1 = _mm512_setzero_pd();
+ __m512d acc2 = _mm512_setzero_pd();
+ __m512d acc3 = _mm512_setzero_pd();
+ size_t k = kk;
+ for (; k + 4 <= k_end8; k += 4) {
+ size_t k8_0 = k / 8, dk_0 = k % 8;
+ size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8;
+ size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8;
+ size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8;
+ __m512d a0 = _mm512_set1_pd(A[i * n + k]);
+ __m512d a1 = _mm512_set1_pd(A[i * n + k + 1]);
+ __m512d a2 = _mm512_set1_pd(A[i * n + k + 2]);
+ __m512d a3 = _mm512_set1_pd(A[i * n + k + 3]);
+ __m512d b0 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]);
+ __m512d b1 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]);
+ __m512d b2 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]);
+ __m512d b3 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]);
+ acc0 = _mm512_fmadd_pd(a0, b0, acc0);
+ acc1 = _mm512_fmadd_pd(a1, b1, acc1);
+ acc2 = _mm512_fmadd_pd(a2, b2, acc2);
+ acc3 = _mm512_fmadd_pd(a3, b3, acc3);
+ }
+ acc0 = _mm512_add_pd(acc0, acc1);
+ acc2 = _mm512_add_pd(acc2, acc3);
+ acc0 = _mm512_add_pd(acc0, acc2);
+ for (; k < k_end; k++) {
+ size_t k8 = k / 8, dk = k % 8;
__m512d a_bcast = _mm512_set1_pd(A[i * n + k]);
__m512d b_val = _mm512_load_pd(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
- acc_vec = _mm512_fmadd_pd(a_bcast, b_val, acc_vec);
+ acc0 = _mm512_fmadd_pd(a_bcast, b_val, acc0);
}
double tmp[8] __attribute__((aligned(64)));
- _mm512_store_pd(tmp, acc_vec);
+ _mm512_store_pd(tmp, acc0);
for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj];
}
for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {