matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

commit df550e3004136b431d34b051cde05616907d9f4b
parent 560fa4c4e737269b1fc591e34d483f3385c56934
Author: finwo <finwo@pm.me>
Date:   Mon, 20 Apr 2026 17:47:17 +0200

More performance improvements

Diffstat:
Msrc/matmul.c | 824+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------
1 file changed, 684 insertions(+), 140 deletions(-)

diff --git a/src/matmul.c b/src/matmul.c @@ -320,6 +320,8 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const if (posix_memalign((void **)&B_packed, 64, p8 * n8 * 64 * sizeof(float)) != 0) return -1; pack_b_f32(n, p, B, B_packed); + float inv_scale = (scale > 1.0) ? 1.0f / (float)scale : 1.0f; + #pragma omp parallel for schedule(static) for (size_t ii = 0; ii < m; ii += ib) { size_t i_end = (ii + ib < m) ? ii + ib : m; @@ -330,28 +332,163 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float acc[64 * 64]; memset(acc, 0, ti * tj * sizeof(float)); - for (size_t kk = 0; kk < n; kk += kb) { - size_t k_end = (kk + kb < n) ? kk + kb : n; - size_t k_end8 = kk + (k_end - kk) / 8 * 8; - for (size_t i = ii; i < i_end; i++) { - size_t li = i - ii; - for (size_t j = jj; j + 8 <= j_end; j += 8) { - size_t lj = j - jj; - size_t j8idx = j / 8; - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = _mm256_setzero_ps(); - __m256 acc2 = _mm256_setzero_ps(); - __m256 acc3 = _mm256_setzero_ps(); - size_t k = kk; + for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) { + size_t li0 = i4 - ii; + size_t li1 = li0 + 1; + size_t li2 = li0 + 2; + size_t li3 = li0 + 3; + size_t rn0 = i4 * n; + size_t rn1 = (i4 + 1) * n; + size_t rn2 = (i4 + 2) * n; + size_t rn3 = (i4 + 3) * n; + + for (size_t j = jj; j + 8 <= j_end; j += 8) { + size_t lj = j - jj; + size_t j8idx = j / 8; + __m256 acc00 = _mm256_setzero_ps(); + __m256 acc01 = _mm256_setzero_ps(); + __m256 acc02 = _mm256_setzero_ps(); + __m256 acc03 = _mm256_setzero_ps(); + __m256 acc10 = _mm256_setzero_ps(); + __m256 acc11 = _mm256_setzero_ps(); + __m256 acc12 = _mm256_setzero_ps(); + __m256 acc13 = _mm256_setzero_ps(); + __m256 acc20 = _mm256_setzero_ps(); + __m256 acc21 = _mm256_setzero_ps(); + __m256 acc22 = _mm256_setzero_ps(); + __m256 acc23 = _mm256_setzero_ps(); + __m256 acc30 = _mm256_setzero_ps(); + __m256 acc31 = _mm256_setzero_ps(); + __m256 acc32 = _mm256_setzero_ps(); + __m256 acc33 = _mm256_setzero_ps(); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end8 = kk + (k_end - kk) / 8 * 8; + size_t k = kk; + for (; k + 4 <= k_end8; k += 4) { + size_t k8_0 = k / 8, dk_0 = k % 8; + size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8; + size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8; + size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8; + __m256 a0 = _mm256_set1_ps(A[rn0 + k]); + __m256 a1 = _mm256_set1_ps(A[rn0 + k + 1]); + __m256 a2 = _mm256_set1_ps(A[rn0 + k + 2]); + __m256 a3 = _mm256_set1_ps(A[rn0 + k + 3]); + __m256 b0 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]); + __m256 b1 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]); + __m256 b2 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]); + __m256 b3 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]); + acc00 = _mm256_fmadd_ps(a0, b0, acc00); + acc01 = _mm256_fmadd_ps(a1, b1, acc01); + acc02 = _mm256_fmadd_ps(a2, b2, acc02); + acc03 = _mm256_fmadd_ps(a3, b3, acc03); + a0 = _mm256_set1_ps(A[rn1 + k]); + a1 = _mm256_set1_ps(A[rn1 + k + 1]); + a2 = _mm256_set1_ps(A[rn1 + k + 2]); + a3 = _mm256_set1_ps(A[rn1 + k + 3]); + acc10 = _mm256_fmadd_ps(a0, b0, acc10); + acc11 = _mm256_fmadd_ps(a1, b1, acc11); + acc12 = _mm256_fmadd_ps(a2, b2, acc12); + acc13 = _mm256_fmadd_ps(a3, b3, acc13); + a0 = _mm256_set1_ps(A[rn2 + k]); + a1 = _mm256_set1_ps(A[rn2 + k + 1]); + a2 = _mm256_set1_ps(A[rn2 + k + 2]); + a3 = _mm256_set1_ps(A[rn2 + k + 3]); + acc20 = _mm256_fmadd_ps(a0, b0, acc20); + acc21 = _mm256_fmadd_ps(a1, b1, acc21); + acc22 = _mm256_fmadd_ps(a2, b2, acc22); + acc23 = _mm256_fmadd_ps(a3, b3, acc23); + a0 = _mm256_set1_ps(A[rn3 + k]); + a1 = _mm256_set1_ps(A[rn3 + k + 1]); + a2 = _mm256_set1_ps(A[rn3 + k + 2]); + a3 = _mm256_set1_ps(A[rn3 + k + 3]); + acc30 = _mm256_fmadd_ps(a0, b0, acc30); + acc31 = _mm256_fmadd_ps(a1, b1, acc31); + acc32 = _mm256_fmadd_ps(a2, b2, acc32); + acc33 = _mm256_fmadd_ps(a3, b3, acc33); + } + for (; k < k_end; k++) { + size_t k8 = k / 8, dk = k % 8; + __m256 a_bcast = _mm256_set1_ps(A[rn0 + k]); + __m256 b_val = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]); + acc00 = _mm256_fmadd_ps(a_bcast, b_val, acc00); + a_bcast = _mm256_set1_ps(A[rn1 + k]); + acc10 = _mm256_fmadd_ps(a_bcast, b_val, acc10); + a_bcast = _mm256_set1_ps(A[rn2 + k]); + acc20 = _mm256_fmadd_ps(a_bcast, b_val, acc20); + a_bcast = _mm256_set1_ps(A[rn3 + k]); + acc30 = _mm256_fmadd_ps(a_bcast, b_val, acc30); + } + } + + acc00 = _mm256_add_ps(acc00, acc01); + acc02 = _mm256_add_ps(acc02, acc03); + acc00 = _mm256_add_ps(acc00, acc02); + acc10 = _mm256_add_ps(acc10, acc11); + acc12 = _mm256_add_ps(acc12, acc13); + acc10 = _mm256_add_ps(acc10, acc12); + acc20 = _mm256_add_ps(acc20, acc21); + acc22 = _mm256_add_ps(acc22, acc23); + acc20 = _mm256_add_ps(acc20, acc22); + acc30 = _mm256_add_ps(acc30, acc31); + acc32 = _mm256_add_ps(acc32, acc33); + acc30 = _mm256_add_ps(acc30, acc32); + + float tmp[8] __attribute__((aligned(32))); + _mm256_store_ps(tmp, acc00); + for (size_t dj = 0; dj < 8; dj++) acc[li0 * tj + lj + dj] += tmp[dj]; + _mm256_store_ps(tmp, acc10); + for (size_t dj = 0; dj < 8; dj++) acc[li1 * tj + lj + dj] += tmp[dj]; + _mm256_store_ps(tmp, acc20); + for (size_t dj = 0; dj < 8; dj++) acc[li2 * tj + lj + dj] += tmp[dj]; + _mm256_store_ps(tmp, acc30); + for (size_t dj = 0; dj < 8; dj++) acc[li3 * tj + lj + dj] += tmp[dj]; + } + + for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { + size_t lj = j - jj; + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + for (size_t k = 0; k < n; k++) { + sum0 += (double)A[rn0 + k] * (double)B[k * p + j]; + sum1 += (double)A[rn1 + k] * (double)B[k * p + j]; + sum2 += (double)A[rn2 + k] * (double)B[k * p + j]; + sum3 += (double)A[rn3 + k] * (double)B[k * p + j]; + } + acc[li0 * tj + lj] += (float)sum0; + acc[li1 * tj + lj] += (float)sum1; + acc[li2 * tj + lj] += (float)sum2; + acc[li3 * tj + lj] += (float)sum3; + } + } + + for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) { + size_t li = i - ii; + size_t rn = i * n; + for (size_t j = jj; j + 8 <= j_end; j += 8) { + size_t lj = j - jj; + size_t j8idx = j / 8; + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end8 = kk + (k_end - kk) / 8 * 8; + size_t k = kk; for (; k + 4 <= k_end8; k += 4) { size_t k8_0 = k / 8, dk_0 = k % 8; size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8; size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8; size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8; - __m256 a0 = _mm256_set1_ps(A[i * n + k]); - __m256 a1 = _mm256_set1_ps(A[i * n + k + 1]); - __m256 a2 = _mm256_set1_ps(A[i * n + k + 2]); - __m256 a3 = _mm256_set1_ps(A[i * n + k + 3]); + __m256 a0 = _mm256_set1_ps(A[rn + k]); + __m256 a1 = _mm256_set1_ps(A[rn + k + 1]); + __m256 a2 = _mm256_set1_ps(A[rn + k + 2]); + __m256 a3 = _mm256_set1_ps(A[rn + k + 3]); __m256 b0 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]); __m256 b1 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]); __m256 b2 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]); @@ -361,24 +498,26 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const acc2 = _mm256_fmadd_ps(a2, b2, acc2); acc3 = _mm256_fmadd_ps(a3, b3, acc3); } - acc0 = _mm256_add_ps(acc0, acc1); - acc2 = _mm256_add_ps(acc2, acc3); - acc0 = _mm256_add_ps(acc0, acc2); for (; k < k_end; k++) { size_t k8 = k / 8, dk = k % 8; - __m256 a_bcast = _mm256_set1_ps(A[i * n + k]); + __m256 a_bcast = _mm256_set1_ps(A[rn + k]); __m256 b_val = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]); acc0 = _mm256_fmadd_ps(a_bcast, b_val, acc0); } - float tmp[8] __attribute__((aligned(32))); - _mm256_store_ps(tmp, acc0); - for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj]; } - for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { - size_t lj = j - jj; - for (size_t k = kk; k < k_end; k++) { - acc[li * tj + lj] += (double)A[i * n + k] * (double)B[k * p + j]; - } + + acc0 = _mm256_add_ps(acc0, acc1); + acc2 = _mm256_add_ps(acc2, acc3); + acc0 = _mm256_add_ps(acc0, acc2); + + float tmp[8] __attribute__((aligned(32))); + _mm256_store_ps(tmp, acc0); + for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj]; + } + for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { + size_t lj = j - jj; + for (size_t k = 0; k < n; k++) { + acc[li * tj + lj] += (double)A[rn + k] * (double)B[k * p + j]; } } } @@ -386,10 +525,8 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const for (size_t i = ii; i < i_end; i++) { size_t li = i - ii; for (size_t j = jj; j < j_end; j++) { - size_t lj = j - jj; - float v = acc[li * tj + lj]; - if (scale > 1.0) v /= (float)scale; - C[i * p + j] = v; + size_t lj = j - jj; + C[i * p + j] = acc[li * tj + lj] * inv_scale; } } } @@ -401,8 +538,7 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const for (size_t k = 0; k < n; k++) { sum += (double)A[i * n + k] * (double)B[k * p + j]; } - if (scale > 1.0) sum /= scale; - C[i * p + j] = (float)sum; + C[i * p + j] = (float)(sum * inv_scale); } } @@ -441,6 +577,8 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons if (posix_memalign((void **)&B_packed, 64, p16 * n16 * 256 * sizeof(float)) != 0) return -1; pack_b_f32_512(n, p, B, B_packed); + float inv_scale = (scale > 1.0) ? 1.0f / (float)scale : 1.0f; + #pragma omp parallel for schedule(static) for (size_t ii = 0; ii < m; ii += ib) { size_t i_end = (ii + ib < m) ? ii + ib : m; @@ -451,28 +589,163 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons float acc[64 * 64]; memset(acc, 0, ti * tj * sizeof(float)); - for (size_t kk = 0; kk < n; kk += kb) { - size_t k_end = (kk + kb < n) ? kk + kb : n; - size_t k_end16 = kk + (k_end - kk) / 16 * 16; - for (size_t i = ii; i < i_end; i++) { - size_t li = i - ii; - for (size_t j = jj; j + 16 <= j_end; j += 16) { - size_t lj = j - jj; - size_t j16idx = j / 16; - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - __m512 acc2 = _mm512_setzero_ps(); - __m512 acc3 = _mm512_setzero_ps(); - size_t k = kk; + for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) { + size_t li0 = i4 - ii; + size_t li1 = li0 + 1; + size_t li2 = li0 + 2; + size_t li3 = li0 + 3; + size_t rn0 = i4 * n; + size_t rn1 = (i4 + 1) * n; + size_t rn2 = (i4 + 2) * n; + size_t rn3 = (i4 + 3) * n; + + for (size_t j = jj; j + 16 <= j_end; j += 16) { + size_t lj = j - jj; + size_t j16idx = j / 16; + __m512 acc00 = _mm512_setzero_ps(); + __m512 acc01 = _mm512_setzero_ps(); + __m512 acc02 = _mm512_setzero_ps(); + __m512 acc03 = _mm512_setzero_ps(); + __m512 acc10 = _mm512_setzero_ps(); + __m512 acc11 = _mm512_setzero_ps(); + __m512 acc12 = _mm512_setzero_ps(); + __m512 acc13 = _mm512_setzero_ps(); + __m512 acc20 = _mm512_setzero_ps(); + __m512 acc21 = _mm512_setzero_ps(); + __m512 acc22 = _mm512_setzero_ps(); + __m512 acc23 = _mm512_setzero_ps(); + __m512 acc30 = _mm512_setzero_ps(); + __m512 acc31 = _mm512_setzero_ps(); + __m512 acc32 = _mm512_setzero_ps(); + __m512 acc33 = _mm512_setzero_ps(); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end16 = kk + (k_end - kk) / 16 * 16; + size_t k = kk; + for (; k + 4 <= k_end16; k += 4) { + size_t k16_0 = k / 16, dk_0 = k % 16; + size_t k16_1 = (k + 1) / 16, dk_1 = (k + 1) % 16; + size_t k16_2 = (k + 2) / 16, dk_2 = (k + 2) % 16; + size_t k16_3 = (k + 3) / 16, dk_3 = (k + 3) % 16; + __m512 a0 = _mm512_set1_ps(A[rn0 + k]); + __m512 a1 = _mm512_set1_ps(A[rn0 + k + 1]); + __m512 a2 = _mm512_set1_ps(A[rn0 + k + 2]); + __m512 a3 = _mm512_set1_ps(A[rn0 + k + 3]); + __m512 b0 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_0) * 256 + dk_0 * 16]); + __m512 b1 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_1) * 256 + dk_1 * 16]); + __m512 b2 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_2) * 256 + dk_2 * 16]); + __m512 b3 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_3) * 256 + dk_3 * 16]); + acc00 = _mm512_fmadd_ps(a0, b0, acc00); + acc01 = _mm512_fmadd_ps(a1, b1, acc01); + acc02 = _mm512_fmadd_ps(a2, b2, acc02); + acc03 = _mm512_fmadd_ps(a3, b3, acc03); + a0 = _mm512_set1_ps(A[rn1 + k]); + a1 = _mm512_set1_ps(A[rn1 + k + 1]); + a2 = _mm512_set1_ps(A[rn1 + k + 2]); + a3 = _mm512_set1_ps(A[rn1 + k + 3]); + acc10 = _mm512_fmadd_ps(a0, b0, acc10); + acc11 = _mm512_fmadd_ps(a1, b1, acc11); + acc12 = _mm512_fmadd_ps(a2, b2, acc12); + acc13 = _mm512_fmadd_ps(a3, b3, acc13); + a0 = _mm512_set1_ps(A[rn2 + k]); + a1 = _mm512_set1_ps(A[rn2 + k + 1]); + a2 = _mm512_set1_ps(A[rn2 + k + 2]); + a3 = _mm512_set1_ps(A[rn2 + k + 3]); + acc20 = _mm512_fmadd_ps(a0, b0, acc20); + acc21 = _mm512_fmadd_ps(a1, b1, acc21); + acc22 = _mm512_fmadd_ps(a2, b2, acc22); + acc23 = _mm512_fmadd_ps(a3, b3, acc23); + a0 = _mm512_set1_ps(A[rn3 + k]); + a1 = _mm512_set1_ps(A[rn3 + k + 1]); + a2 = _mm512_set1_ps(A[rn3 + k + 2]); + a3 = _mm512_set1_ps(A[rn3 + k + 3]); + acc30 = _mm512_fmadd_ps(a0, b0, acc30); + acc31 = _mm512_fmadd_ps(a1, b1, acc31); + acc32 = _mm512_fmadd_ps(a2, b2, acc32); + acc33 = _mm512_fmadd_ps(a3, b3, acc33); + } + for (; k < k_end; k++) { + size_t k16 = k / 16, dk = k % 16; + __m512 a_bcast = _mm512_set1_ps(A[rn0 + k]); + __m512 b_val = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]); + acc00 = _mm512_fmadd_ps(a_bcast, b_val, acc00); + a_bcast = _mm512_set1_ps(A[rn1 + k]); + acc10 = _mm512_fmadd_ps(a_bcast, b_val, acc10); + a_bcast = _mm512_set1_ps(A[rn2 + k]); + acc20 = _mm512_fmadd_ps(a_bcast, b_val, acc20); + a_bcast = _mm512_set1_ps(A[rn3 + k]); + acc30 = _mm512_fmadd_ps(a_bcast, b_val, acc30); + } + } + + acc00 = _mm512_add_ps(acc00, acc01); + acc02 = _mm512_add_ps(acc02, acc03); + acc00 = _mm512_add_ps(acc00, acc02); + acc10 = _mm512_add_ps(acc10, acc11); + acc12 = _mm512_add_ps(acc12, acc13); + acc10 = _mm512_add_ps(acc10, acc12); + acc20 = _mm512_add_ps(acc20, acc21); + acc22 = _mm512_add_ps(acc22, acc23); + acc20 = _mm512_add_ps(acc20, acc22); + acc30 = _mm512_add_ps(acc30, acc31); + acc32 = _mm512_add_ps(acc32, acc33); + acc30 = _mm512_add_ps(acc30, acc32); + + float tmp[16] __attribute__((aligned(64))); + _mm512_store_ps(tmp, acc00); + for (size_t dj = 0; dj < 16; dj++) acc[li0 * tj + lj + dj] += tmp[dj]; + _mm512_store_ps(tmp, acc10); + for (size_t dj = 0; dj < 16; dj++) acc[li1 * tj + lj + dj] += tmp[dj]; + _mm512_store_ps(tmp, acc20); + for (size_t dj = 0; dj < 16; dj++) acc[li2 * tj + lj + dj] += tmp[dj]; + _mm512_store_ps(tmp, acc30); + for (size_t dj = 0; dj < 16; dj++) acc[li3 * tj + lj + dj] += tmp[dj]; + } + + for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) { + size_t lj = j - jj; + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + for (size_t k = 0; k < n; k++) { + sum0 += (double)A[rn0 + k] * (double)B[k * p + j]; + sum1 += (double)A[rn1 + k] * (double)B[k * p + j]; + sum2 += (double)A[rn2 + k] * (double)B[k * p + j]; + sum3 += (double)A[rn3 + k] * (double)B[k * p + j]; + } + acc[li0 * tj + lj] += (float)sum0; + acc[li1 * tj + lj] += (float)sum1; + acc[li2 * tj + lj] += (float)sum2; + acc[li3 * tj + lj] += (float)sum3; + } + } + + for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) { + size_t li = i - ii; + size_t rn = i * n; + for (size_t j = jj; j + 16 <= j_end; j += 16) { + size_t lj = j - jj; + size_t j16idx = j / 16; + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end16 = kk + (k_end - kk) / 16 * 16; + size_t k = kk; for (; k + 4 <= k_end16; k += 4) { size_t k16_0 = k / 16, dk_0 = k % 16; size_t k16_1 = (k + 1) / 16, dk_1 = (k + 1) % 16; size_t k16_2 = (k + 2) / 16, dk_2 = (k + 2) % 16; size_t k16_3 = (k + 3) / 16, dk_3 = (k + 3) % 16; - __m512 a0 = _mm512_set1_ps(A[i * n + k]); - __m512 a1 = _mm512_set1_ps(A[i * n + k + 1]); - __m512 a2 = _mm512_set1_ps(A[i * n + k + 2]); - __m512 a3 = _mm512_set1_ps(A[i * n + k + 3]); + __m512 a0 = _mm512_set1_ps(A[rn + k]); + __m512 a1 = _mm512_set1_ps(A[rn + k + 1]); + __m512 a2 = _mm512_set1_ps(A[rn + k + 2]); + __m512 a3 = _mm512_set1_ps(A[rn + k + 3]); __m512 b0 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_0) * 256 + dk_0 * 16]); __m512 b1 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_1) * 256 + dk_1 * 16]); __m512 b2 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_2) * 256 + dk_2 * 16]); @@ -482,24 +755,26 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons acc2 = _mm512_fmadd_ps(a2, b2, acc2); acc3 = _mm512_fmadd_ps(a3, b3, acc3); } - acc0 = _mm512_add_ps(acc0, acc1); - acc2 = _mm512_add_ps(acc2, acc3); - acc0 = _mm512_add_ps(acc0, acc2); for (; k < k_end; k++) { size_t k16 = k / 16, dk = k % 16; - __m512 a_bcast = _mm512_set1_ps(A[i * n + k]); + __m512 a_bcast = _mm512_set1_ps(A[rn + k]); __m512 b_val = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]); acc0 = _mm512_fmadd_ps(a_bcast, b_val, acc0); } - float tmp[16] __attribute__((aligned(64))); - _mm512_store_ps(tmp, acc0); - for (size_t dj = 0; dj < 16; dj++) acc[li * tj + lj + dj] += tmp[dj]; } - for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) { - size_t lj = j - jj; - for (size_t k = kk; k < k_end; k++) { - acc[li * tj + lj] += (double)A[i * n + k] * (double)B[k * p + j]; - } + + acc0 = _mm512_add_ps(acc0, acc1); + acc2 = _mm512_add_ps(acc2, acc3); + acc0 = _mm512_add_ps(acc0, acc2); + + float tmp[16] __attribute__((aligned(64))); + _mm512_store_ps(tmp, acc0); + for (size_t dj = 0; dj < 16; dj++) acc[li * tj + lj + dj] += tmp[dj]; + } + for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) { + size_t lj = j - jj; + for (size_t k = 0; k < n; k++) { + acc[li * tj + lj] += (double)A[rn + k] * (double)B[k * p + j]; } } } @@ -507,10 +782,8 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons for (size_t i = ii; i < i_end; i++) { size_t li = i - ii; for (size_t j = jj; j < j_end; j++) { - size_t lj = j - jj; - float v = acc[li * tj + lj]; - if (scale > 1.0) v /= (float)scale; - C[i * p + j] = v; + size_t lj = j - jj; + C[i * p + j] = acc[li * tj + lj] * inv_scale; } } } @@ -522,8 +795,7 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons for (size_t k = 0; k < n; k++) { sum += (double)A[i * n + k] * (double)B[k * p + j]; } - if (scale > 1.0) sum /= scale; - C[i * p + j] = (float)sum; + C[i * p + j] = (float)(sum * inv_scale); } } @@ -632,6 +904,8 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const if (posix_memalign((void **)&B_packed, 64, p4 * n4 * 16 * sizeof(double)) != 0) return -1; pack_b_f64(n, p, B, B_packed); + double inv_scale = (scale > 1.0) ? 1.0 / scale : 1.0; + #pragma omp parallel for schedule(static) for (size_t ii = 0; ii < m; ii += ib) { size_t i_end = (ii + ib < m) ? ii + ib : m; @@ -642,28 +916,163 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double acc[64 * 64]; memset(acc, 0, ti * tj * sizeof(double)); - for (size_t kk = 0; kk < n; kk += kb) { - size_t k_end = (kk + kb < n) ? kk + kb : n; - size_t k_end4 = kk + (k_end - kk) / 4 * 4; - for (size_t i = ii; i < i_end; i++) { - size_t li = i - ii; - for (size_t j = jj; j + 4 <= j_end; j += 4) { - size_t lj = j - jj; - size_t j4idx = j / 4; - __m256d acc0 = _mm256_setzero_pd(); - __m256d acc1 = _mm256_setzero_pd(); - __m256d acc2 = _mm256_setzero_pd(); - __m256d acc3 = _mm256_setzero_pd(); - size_t k = kk; + for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) { + size_t li0 = i4 - ii; + size_t li1 = li0 + 1; + size_t li2 = li0 + 2; + size_t li3 = li0 + 3; + size_t rn0 = i4 * n; + size_t rn1 = (i4 + 1) * n; + size_t rn2 = (i4 + 2) * n; + size_t rn3 = (i4 + 3) * n; + + for (size_t j = jj; j + 4 <= j_end; j += 4) { + size_t lj = j - jj; + size_t j4idx = j / 4; + __m256d acc00 = _mm256_setzero_pd(); + __m256d acc01 = _mm256_setzero_pd(); + __m256d acc02 = _mm256_setzero_pd(); + __m256d acc03 = _mm256_setzero_pd(); + __m256d acc10 = _mm256_setzero_pd(); + __m256d acc11 = _mm256_setzero_pd(); + __m256d acc12 = _mm256_setzero_pd(); + __m256d acc13 = _mm256_setzero_pd(); + __m256d acc20 = _mm256_setzero_pd(); + __m256d acc21 = _mm256_setzero_pd(); + __m256d acc22 = _mm256_setzero_pd(); + __m256d acc23 = _mm256_setzero_pd(); + __m256d acc30 = _mm256_setzero_pd(); + __m256d acc31 = _mm256_setzero_pd(); + __m256d acc32 = _mm256_setzero_pd(); + __m256d acc33 = _mm256_setzero_pd(); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end4 = kk + (k_end - kk) / 4 * 4; + size_t k = kk; + for (; k + 4 <= k_end4; k += 4) { + size_t k4_0 = k / 4, dk_0 = k % 4; + size_t k4_1 = (k + 1) / 4, dk_1 = (k + 1) % 4; + size_t k4_2 = (k + 2) / 4, dk_2 = (k + 2) % 4; + size_t k4_3 = (k + 3) / 4, dk_3 = (k + 3) % 4; + __m256d a0 = _mm256_set1_pd(A[rn0 + k]); + __m256d a1 = _mm256_set1_pd(A[rn0 + k + 1]); + __m256d a2 = _mm256_set1_pd(A[rn0 + k + 2]); + __m256d a3 = _mm256_set1_pd(A[rn0 + k + 3]); + __m256d b0 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_0) * 16 + dk_0 * 4]); + __m256d b1 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_1) * 16 + dk_1 * 4]); + __m256d b2 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_2) * 16 + dk_2 * 4]); + __m256d b3 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_3) * 16 + dk_3 * 4]); + acc00 = _mm256_fmadd_pd(a0, b0, acc00); + acc01 = _mm256_fmadd_pd(a1, b1, acc01); + acc02 = _mm256_fmadd_pd(a2, b2, acc02); + acc03 = _mm256_fmadd_pd(a3, b3, acc03); + a0 = _mm256_set1_pd(A[rn1 + k]); + a1 = _mm256_set1_pd(A[rn1 + k + 1]); + a2 = _mm256_set1_pd(A[rn1 + k + 2]); + a3 = _mm256_set1_pd(A[rn1 + k + 3]); + acc10 = _mm256_fmadd_pd(a0, b0, acc10); + acc11 = _mm256_fmadd_pd(a1, b1, acc11); + acc12 = _mm256_fmadd_pd(a2, b2, acc12); + acc13 = _mm256_fmadd_pd(a3, b3, acc13); + a0 = _mm256_set1_pd(A[rn2 + k]); + a1 = _mm256_set1_pd(A[rn2 + k + 1]); + a2 = _mm256_set1_pd(A[rn2 + k + 2]); + a3 = _mm256_set1_pd(A[rn2 + k + 3]); + acc20 = _mm256_fmadd_pd(a0, b0, acc20); + acc21 = _mm256_fmadd_pd(a1, b1, acc21); + acc22 = _mm256_fmadd_pd(a2, b2, acc22); + acc23 = _mm256_fmadd_pd(a3, b3, acc23); + a0 = _mm256_set1_pd(A[rn3 + k]); + a1 = _mm256_set1_pd(A[rn3 + k + 1]); + a2 = _mm256_set1_pd(A[rn3 + k + 2]); + a3 = _mm256_set1_pd(A[rn3 + k + 3]); + acc30 = _mm256_fmadd_pd(a0, b0, acc30); + acc31 = _mm256_fmadd_pd(a1, b1, acc31); + acc32 = _mm256_fmadd_pd(a2, b2, acc32); + acc33 = _mm256_fmadd_pd(a3, b3, acc33); + } + for (; k < k_end; k++) { + size_t k4 = k / 4, dk = k % 4; + __m256d a_bcast = _mm256_set1_pd(A[rn0 + k]); + __m256d b_val = _mm256_load_pd(&B_packed[(j4idx * n4 + k4) * 16 + dk * 4]); + acc00 = _mm256_fmadd_pd(a_bcast, b_val, acc00); + a_bcast = _mm256_set1_pd(A[rn1 + k]); + acc10 = _mm256_fmadd_pd(a_bcast, b_val, acc10); + a_bcast = _mm256_set1_pd(A[rn2 + k]); + acc20 = _mm256_fmadd_pd(a_bcast, b_val, acc20); + a_bcast = _mm256_set1_pd(A[rn3 + k]); + acc30 = _mm256_fmadd_pd(a_bcast, b_val, acc30); + } + } + + acc00 = _mm256_add_pd(acc00, acc01); + acc02 = _mm256_add_pd(acc02, acc03); + acc00 = _mm256_add_pd(acc00, acc02); + acc10 = _mm256_add_pd(acc10, acc11); + acc12 = _mm256_add_pd(acc12, acc13); + acc10 = _mm256_add_pd(acc10, acc12); + acc20 = _mm256_add_pd(acc20, acc21); + acc22 = _mm256_add_pd(acc22, acc23); + acc20 = _mm256_add_pd(acc20, acc22); + acc30 = _mm256_add_pd(acc30, acc31); + acc32 = _mm256_add_pd(acc32, acc33); + acc30 = _mm256_add_pd(acc30, acc32); + + double tmp[4] __attribute__((aligned(32))); + _mm256_store_pd(tmp, acc00); + for (size_t dj = 0; dj < 4; dj++) acc[li0 * tj + lj + dj] += tmp[dj]; + _mm256_store_pd(tmp, acc10); + for (size_t dj = 0; dj < 4; dj++) acc[li1 * tj + lj + dj] += tmp[dj]; + _mm256_store_pd(tmp, acc20); + for (size_t dj = 0; dj < 4; dj++) acc[li2 * tj + lj + dj] += tmp[dj]; + _mm256_store_pd(tmp, acc30); + for (size_t dj = 0; dj < 4; dj++) acc[li3 * tj + lj + dj] += tmp[dj]; + } + + for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) { + size_t lj = j - jj; + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + for (size_t k = 0; k < n; k++) { + sum0 += A[rn0 + k] * B[k * p + j]; + sum1 += A[rn1 + k] * B[k * p + j]; + sum2 += A[rn2 + k] * B[k * p + j]; + sum3 += A[rn3 + k] * B[k * p + j]; + } + acc[li0 * tj + lj] += sum0; + acc[li1 * tj + lj] += sum1; + acc[li2 * tj + lj] += sum2; + acc[li3 * tj + lj] += sum3; + } + } + + for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) { + size_t li = i - ii; + size_t rn = i * n; + for (size_t j = jj; j + 4 <= j_end; j += 4) { + size_t lj = j - jj; + size_t j4idx = j / 4; + __m256d acc0 = _mm256_setzero_pd(); + __m256d acc1 = _mm256_setzero_pd(); + __m256d acc2 = _mm256_setzero_pd(); + __m256d acc3 = _mm256_setzero_pd(); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end4 = kk + (k_end - kk) / 4 * 4; + size_t k = kk; for (; k + 4 <= k_end4; k += 4) { size_t k4_0 = k / 4, dk_0 = k % 4; size_t k4_1 = (k + 1) / 4, dk_1 = (k + 1) % 4; size_t k4_2 = (k + 2) / 4, dk_2 = (k + 2) % 4; size_t k4_3 = (k + 3) / 4, dk_3 = (k + 3) % 4; - __m256d a0 = _mm256_set1_pd(A[i * n + k]); - __m256d a1 = _mm256_set1_pd(A[i * n + k + 1]); - __m256d a2 = _mm256_set1_pd(A[i * n + k + 2]); - __m256d a3 = _mm256_set1_pd(A[i * n + k + 3]); + __m256d a0 = _mm256_set1_pd(A[rn + k]); + __m256d a1 = _mm256_set1_pd(A[rn + k + 1]); + __m256d a2 = _mm256_set1_pd(A[rn + k + 2]); + __m256d a3 = _mm256_set1_pd(A[rn + k + 3]); __m256d b0 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_0) * 16 + dk_0 * 4]); __m256d b1 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_1) * 16 + dk_1 * 4]); __m256d b2 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_2) * 16 + dk_2 * 4]); @@ -673,24 +1082,26 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const acc2 = _mm256_fmadd_pd(a2, b2, acc2); acc3 = _mm256_fmadd_pd(a3, b3, acc3); } - acc0 = _mm256_add_pd(acc0, acc1); - acc2 = _mm256_add_pd(acc2, acc3); - acc0 = _mm256_add_pd(acc0, acc2); for (; k < k_end; k++) { size_t k4 = k / 4, dk = k % 4; - __m256d a_bcast = _mm256_set1_pd(A[i * n + k]); + __m256d a_bcast = _mm256_set1_pd(A[rn + k]); __m256d b_val = _mm256_load_pd(&B_packed[(j4idx * n4 + k4) * 16 + dk * 4]); acc0 = _mm256_fmadd_pd(a_bcast, b_val, acc0); } - double tmp[4] __attribute__((aligned(32))); - _mm256_store_pd(tmp, acc0); - for (size_t dj = 0; dj < 4; dj++) acc[li * tj + lj + dj] += tmp[dj]; } - for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) { - size_t lj = j - jj; - for (size_t k = kk; k < k_end; k++) { - acc[li * tj + lj] += A[i * n + k] * B[k * p + j]; - } + + acc0 = _mm256_add_pd(acc0, acc1); + acc2 = _mm256_add_pd(acc2, acc3); + acc0 = _mm256_add_pd(acc0, acc2); + + double tmp[4] __attribute__((aligned(32))); + _mm256_store_pd(tmp, acc0); + for (size_t dj = 0; dj < 4; dj++) acc[li * tj + lj + dj] += tmp[dj]; + } + for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) { + size_t lj = j - jj; + for (size_t k = 0; k < n; k++) { + acc[li * tj + lj] += A[rn + k] * B[k * p + j]; } } } @@ -698,10 +1109,8 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const for (size_t i = ii; i < i_end; i++) { size_t li = i - ii; for (size_t j = jj; j < j_end; j++) { - size_t lj = j - jj; - double v = acc[li * tj + lj]; - if (scale > 1.0) v /= scale; - C[i * p + j] = v; + size_t lj = j - jj; + C[i * p + j] = acc[li * tj + lj] * inv_scale; } } } @@ -713,8 +1122,7 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const for (size_t k = 0; k < n; k++) { sum += A[i * n + k] * B[k * p + j]; } - if (scale > 1.0) sum /= scale; - C[i * p + j] = sum; + C[i * p + j] = sum * inv_scale; } } @@ -753,6 +1161,8 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con if (posix_memalign((void **)&B_packed, 64, p8 * n8 * 64 * sizeof(double)) != 0) return -1; pack_b_f64_512(n, p, B, B_packed); + double inv_scale = (scale > 1.0) ? 1.0 / scale : 1.0; + #pragma omp parallel for schedule(static) for (size_t ii = 0; ii < m; ii += ib) { size_t i_end = (ii + ib < m) ? ii + ib : m; @@ -763,28 +1173,163 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con double acc[64 * 64]; memset(acc, 0, ti * tj * sizeof(double)); - for (size_t kk = 0; kk < n; kk += kb) { - size_t k_end = (kk + kb < n) ? kk + kb : n; - size_t k_end8 = kk + (k_end - kk) / 8 * 8; - for (size_t i = ii; i < i_end; i++) { - size_t li = i - ii; - for (size_t j = jj; j + 8 <= j_end; j += 8) { - size_t lj = j - jj; - size_t j8idx = j / 8; - __m512d acc0 = _mm512_setzero_pd(); - __m512d acc1 = _mm512_setzero_pd(); - __m512d acc2 = _mm512_setzero_pd(); - __m512d acc3 = _mm512_setzero_pd(); - size_t k = kk; + for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) { + size_t li0 = i4 - ii; + size_t li1 = li0 + 1; + size_t li2 = li0 + 2; + size_t li3 = li0 + 3; + size_t rn0 = i4 * n; + size_t rn1 = (i4 + 1) * n; + size_t rn2 = (i4 + 2) * n; + size_t rn3 = (i4 + 3) * n; + + for (size_t j = jj; j + 8 <= j_end; j += 8) { + size_t lj = j - jj; + size_t j8idx = j / 8; + __m512d acc00 = _mm512_setzero_pd(); + __m512d acc01 = _mm512_setzero_pd(); + __m512d acc02 = _mm512_setzero_pd(); + __m512d acc03 = _mm512_setzero_pd(); + __m512d acc10 = _mm512_setzero_pd(); + __m512d acc11 = _mm512_setzero_pd(); + __m512d acc12 = _mm512_setzero_pd(); + __m512d acc13 = _mm512_setzero_pd(); + __m512d acc20 = _mm512_setzero_pd(); + __m512d acc21 = _mm512_setzero_pd(); + __m512d acc22 = _mm512_setzero_pd(); + __m512d acc23 = _mm512_setzero_pd(); + __m512d acc30 = _mm512_setzero_pd(); + __m512d acc31 = _mm512_setzero_pd(); + __m512d acc32 = _mm512_setzero_pd(); + __m512d acc33 = _mm512_setzero_pd(); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end8 = kk + (k_end - kk) / 8 * 8; + size_t k = kk; for (; k + 4 <= k_end8; k += 4) { size_t k8_0 = k / 8, dk_0 = k % 8; size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8; size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8; size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8; - __m512d a0 = _mm512_set1_pd(A[i * n + k]); - __m512d a1 = _mm512_set1_pd(A[i * n + k + 1]); - __m512d a2 = _mm512_set1_pd(A[i * n + k + 2]); - __m512d a3 = _mm512_set1_pd(A[i * n + k + 3]); + __m512d a0 = _mm512_set1_pd(A[rn0 + k]); + __m512d a1 = _mm512_set1_pd(A[rn0 + k + 1]); + __m512d a2 = _mm512_set1_pd(A[rn0 + k + 2]); + __m512d a3 = _mm512_set1_pd(A[rn0 + k + 3]); + __m512d b0 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]); + __m512d b1 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]); + __m512d b2 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]); + __m512d b3 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]); + acc00 = _mm512_fmadd_pd(a0, b0, acc00); + acc01 = _mm512_fmadd_pd(a1, b1, acc01); + acc02 = _mm512_fmadd_pd(a2, b2, acc02); + acc03 = _mm512_fmadd_pd(a3, b3, acc03); + a0 = _mm512_set1_pd(A[rn1 + k]); + a1 = _mm512_set1_pd(A[rn1 + k + 1]); + a2 = _mm512_set1_pd(A[rn1 + k + 2]); + a3 = _mm512_set1_pd(A[rn1 + k + 3]); + acc10 = _mm512_fmadd_pd(a0, b0, acc10); + acc11 = _mm512_fmadd_pd(a1, b1, acc11); + acc12 = _mm512_fmadd_pd(a2, b2, acc12); + acc13 = _mm512_fmadd_pd(a3, b3, acc13); + a0 = _mm512_set1_pd(A[rn2 + k]); + a1 = _mm512_set1_pd(A[rn2 + k + 1]); + a2 = _mm512_set1_pd(A[rn2 + k + 2]); + a3 = _mm512_set1_pd(A[rn2 + k + 3]); + acc20 = _mm512_fmadd_pd(a0, b0, acc20); + acc21 = _mm512_fmadd_pd(a1, b1, acc21); + acc22 = _mm512_fmadd_pd(a2, b2, acc22); + acc23 = _mm512_fmadd_pd(a3, b3, acc23); + a0 = _mm512_set1_pd(A[rn3 + k]); + a1 = _mm512_set1_pd(A[rn3 + k + 1]); + a2 = _mm512_set1_pd(A[rn3 + k + 2]); + a3 = _mm512_set1_pd(A[rn3 + k + 3]); + acc30 = _mm512_fmadd_pd(a0, b0, acc30); + acc31 = _mm512_fmadd_pd(a1, b1, acc31); + acc32 = _mm512_fmadd_pd(a2, b2, acc32); + acc33 = _mm512_fmadd_pd(a3, b3, acc33); + } + for (; k < k_end; k++) { + size_t k8 = k / 8, dk = k % 8; + __m512d a_bcast = _mm512_set1_pd(A[rn0 + k]); + __m512d b_val = _mm512_load_pd(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]); + acc00 = _mm512_fmadd_pd(a_bcast, b_val, acc00); + a_bcast = _mm512_set1_pd(A[rn1 + k]); + acc10 = _mm512_fmadd_pd(a_bcast, b_val, acc10); + a_bcast = _mm512_set1_pd(A[rn2 + k]); + acc20 = _mm512_fmadd_pd(a_bcast, b_val, acc20); + a_bcast = _mm512_set1_pd(A[rn3 + k]); + acc30 = _mm512_fmadd_pd(a_bcast, b_val, acc30); + } + } + + acc00 = _mm512_add_pd(acc00, acc01); + acc02 = _mm512_add_pd(acc02, acc03); + acc00 = _mm512_add_pd(acc00, acc02); + acc10 = _mm512_add_pd(acc10, acc11); + acc12 = _mm512_add_pd(acc12, acc13); + acc10 = _mm512_add_pd(acc10, acc12); + acc20 = _mm512_add_pd(acc20, acc21); + acc22 = _mm512_add_pd(acc22, acc23); + acc20 = _mm512_add_pd(acc20, acc22); + acc30 = _mm512_add_pd(acc30, acc31); + acc32 = _mm512_add_pd(acc32, acc33); + acc30 = _mm512_add_pd(acc30, acc32); + + double tmp[8] __attribute__((aligned(64))); + _mm512_store_pd(tmp, acc00); + for (size_t dj = 0; dj < 8; dj++) acc[li0 * tj + lj + dj] += tmp[dj]; + _mm512_store_pd(tmp, acc10); + for (size_t dj = 0; dj < 8; dj++) acc[li1 * tj + lj + dj] += tmp[dj]; + _mm512_store_pd(tmp, acc20); + for (size_t dj = 0; dj < 8; dj++) acc[li2 * tj + lj + dj] += tmp[dj]; + _mm512_store_pd(tmp, acc30); + for (size_t dj = 0; dj < 8; dj++) acc[li3 * tj + lj + dj] += tmp[dj]; + } + + for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { + size_t lj = j - jj; + double sum0 = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + double sum3 = 0.0; + for (size_t k = 0; k < n; k++) { + sum0 += A[rn0 + k] * B[k * p + j]; + sum1 += A[rn1 + k] * B[k * p + j]; + sum2 += A[rn2 + k] * B[k * p + j]; + sum3 += A[rn3 + k] * B[k * p + j]; + } + acc[li0 * tj + lj] += sum0; + acc[li1 * tj + lj] += sum1; + acc[li2 * tj + lj] += sum2; + acc[li3 * tj + lj] += sum3; + } + } + + for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) { + size_t li = i - ii; + size_t rn = i * n; + for (size_t j = jj; j + 8 <= j_end; j += 8) { + size_t lj = j - jj; + size_t j8idx = j / 8; + __m512d acc0 = _mm512_setzero_pd(); + __m512d acc1 = _mm512_setzero_pd(); + __m512d acc2 = _mm512_setzero_pd(); + __m512d acc3 = _mm512_setzero_pd(); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + size_t k_end8 = kk + (k_end - kk) / 8 * 8; + size_t k = kk; + for (; k + 4 <= k_end8; k += 4) { + size_t k8_0 = k / 8, dk_0 = k % 8; + size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8; + size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8; + size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8; + __m512d a0 = _mm512_set1_pd(A[rn + k]); + __m512d a1 = _mm512_set1_pd(A[rn + k + 1]); + __m512d a2 = _mm512_set1_pd(A[rn + k + 2]); + __m512d a3 = _mm512_set1_pd(A[rn + k + 3]); __m512d b0 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]); __m512d b1 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]); __m512d b2 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]); @@ -794,24 +1339,26 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con acc2 = _mm512_fmadd_pd(a2, b2, acc2); acc3 = _mm512_fmadd_pd(a3, b3, acc3); } - acc0 = _mm512_add_pd(acc0, acc1); - acc2 = _mm512_add_pd(acc2, acc3); - acc0 = _mm512_add_pd(acc0, acc2); for (; k < k_end; k++) { size_t k8 = k / 8, dk = k % 8; - __m512d a_bcast = _mm512_set1_pd(A[i * n + k]); + __m512d a_bcast = _mm512_set1_pd(A[rn + k]); __m512d b_val = _mm512_load_pd(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]); acc0 = _mm512_fmadd_pd(a_bcast, b_val, acc0); } - double tmp[8] __attribute__((aligned(64))); - _mm512_store_pd(tmp, acc0); - for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj]; } - for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { - size_t lj = j - jj; - for (size_t k = kk; k < k_end; k++) { - acc[li * tj + lj] += A[i * n + k] * B[k * p + j]; - } + + acc0 = _mm512_add_pd(acc0, acc1); + acc2 = _mm512_add_pd(acc2, acc3); + acc0 = _mm512_add_pd(acc0, acc2); + + double tmp[8] __attribute__((aligned(64))); + _mm512_store_pd(tmp, acc0); + for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj]; + } + for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { + size_t lj = j - jj; + for (size_t k = 0; k < n; k++) { + acc[li * tj + lj] += A[rn + k] * B[k * p + j]; } } } @@ -819,10 +1366,8 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con for (size_t i = ii; i < i_end; i++) { size_t li = i - ii; for (size_t j = jj; j < j_end; j++) { - size_t lj = j - jj; - double v = acc[li * tj + lj]; - if (scale > 1.0) v /= scale; - C[i * p + j] = v; + size_t lj = j - jj; + C[i * p + j] = acc[li * tj + lj] * inv_scale; } } } @@ -834,8 +1379,7 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con for (size_t k = 0; k < n; k++) { sum += A[i * n + k] * B[k * p + j]; } - if (scale > 1.0) sum /= scale; - C[i * p + j] = sum; + C[i * p + j] = sum * inv_scale; } }