commit df550e3004136b431d34b051cde05616907d9f4b
parent 560fa4c4e737269b1fc591e34d483f3385c56934
Author: finwo <finwo@pm.me>
Date: Mon, 20 Apr 2026 17:47:17 +0200
More performance improvements
Diffstat:
| M | src/matmul.c | | | 824 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------- |
1 file changed, 684 insertions(+), 140 deletions(-)
diff --git a/src/matmul.c b/src/matmul.c
@@ -320,6 +320,8 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const
if (posix_memalign((void **)&B_packed, 64, p8 * n8 * 64 * sizeof(float)) != 0) return -1;
pack_b_f32(n, p, B, B_packed);
+ float inv_scale = (scale > 1.0) ? 1.0f / (float)scale : 1.0f;
+
#pragma omp parallel for schedule(static)
for (size_t ii = 0; ii < m; ii += ib) {
size_t i_end = (ii + ib < m) ? ii + ib : m;
@@ -330,28 +332,163 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const
float acc[64 * 64];
memset(acc, 0, ti * tj * sizeof(float));
- for (size_t kk = 0; kk < n; kk += kb) {
- size_t k_end = (kk + kb < n) ? kk + kb : n;
- size_t k_end8 = kk + (k_end - kk) / 8 * 8;
- for (size_t i = ii; i < i_end; i++) {
- size_t li = i - ii;
- for (size_t j = jj; j + 8 <= j_end; j += 8) {
- size_t lj = j - jj;
- size_t j8idx = j / 8;
- __m256 acc0 = _mm256_setzero_ps();
- __m256 acc1 = _mm256_setzero_ps();
- __m256 acc2 = _mm256_setzero_ps();
- __m256 acc3 = _mm256_setzero_ps();
- size_t k = kk;
+ for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) {
+ size_t li0 = i4 - ii;
+ size_t li1 = li0 + 1;
+ size_t li2 = li0 + 2;
+ size_t li3 = li0 + 3;
+ size_t rn0 = i4 * n;
+ size_t rn1 = (i4 + 1) * n;
+ size_t rn2 = (i4 + 2) * n;
+ size_t rn3 = (i4 + 3) * n;
+
+ for (size_t j = jj; j + 8 <= j_end; j += 8) {
+ size_t lj = j - jj;
+ size_t j8idx = j / 8;
+ __m256 acc00 = _mm256_setzero_ps();
+ __m256 acc01 = _mm256_setzero_ps();
+ __m256 acc02 = _mm256_setzero_ps();
+ __m256 acc03 = _mm256_setzero_ps();
+ __m256 acc10 = _mm256_setzero_ps();
+ __m256 acc11 = _mm256_setzero_ps();
+ __m256 acc12 = _mm256_setzero_ps();
+ __m256 acc13 = _mm256_setzero_ps();
+ __m256 acc20 = _mm256_setzero_ps();
+ __m256 acc21 = _mm256_setzero_ps();
+ __m256 acc22 = _mm256_setzero_ps();
+ __m256 acc23 = _mm256_setzero_ps();
+ __m256 acc30 = _mm256_setzero_ps();
+ __m256 acc31 = _mm256_setzero_ps();
+ __m256 acc32 = _mm256_setzero_ps();
+ __m256 acc33 = _mm256_setzero_ps();
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end8 = kk + (k_end - kk) / 8 * 8;
+ size_t k = kk;
+ for (; k + 4 <= k_end8; k += 4) {
+ size_t k8_0 = k / 8, dk_0 = k % 8;
+ size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8;
+ size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8;
+ size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8;
+ __m256 a0 = _mm256_set1_ps(A[rn0 + k]);
+ __m256 a1 = _mm256_set1_ps(A[rn0 + k + 1]);
+ __m256 a2 = _mm256_set1_ps(A[rn0 + k + 2]);
+ __m256 a3 = _mm256_set1_ps(A[rn0 + k + 3]);
+ __m256 b0 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]);
+ __m256 b1 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]);
+ __m256 b2 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]);
+ __m256 b3 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]);
+ acc00 = _mm256_fmadd_ps(a0, b0, acc00);
+ acc01 = _mm256_fmadd_ps(a1, b1, acc01);
+ acc02 = _mm256_fmadd_ps(a2, b2, acc02);
+ acc03 = _mm256_fmadd_ps(a3, b3, acc03);
+ a0 = _mm256_set1_ps(A[rn1 + k]);
+ a1 = _mm256_set1_ps(A[rn1 + k + 1]);
+ a2 = _mm256_set1_ps(A[rn1 + k + 2]);
+ a3 = _mm256_set1_ps(A[rn1 + k + 3]);
+ acc10 = _mm256_fmadd_ps(a0, b0, acc10);
+ acc11 = _mm256_fmadd_ps(a1, b1, acc11);
+ acc12 = _mm256_fmadd_ps(a2, b2, acc12);
+ acc13 = _mm256_fmadd_ps(a3, b3, acc13);
+ a0 = _mm256_set1_ps(A[rn2 + k]);
+ a1 = _mm256_set1_ps(A[rn2 + k + 1]);
+ a2 = _mm256_set1_ps(A[rn2 + k + 2]);
+ a3 = _mm256_set1_ps(A[rn2 + k + 3]);
+ acc20 = _mm256_fmadd_ps(a0, b0, acc20);
+ acc21 = _mm256_fmadd_ps(a1, b1, acc21);
+ acc22 = _mm256_fmadd_ps(a2, b2, acc22);
+ acc23 = _mm256_fmadd_ps(a3, b3, acc23);
+ a0 = _mm256_set1_ps(A[rn3 + k]);
+ a1 = _mm256_set1_ps(A[rn3 + k + 1]);
+ a2 = _mm256_set1_ps(A[rn3 + k + 2]);
+ a3 = _mm256_set1_ps(A[rn3 + k + 3]);
+ acc30 = _mm256_fmadd_ps(a0, b0, acc30);
+ acc31 = _mm256_fmadd_ps(a1, b1, acc31);
+ acc32 = _mm256_fmadd_ps(a2, b2, acc32);
+ acc33 = _mm256_fmadd_ps(a3, b3, acc33);
+ }
+ for (; k < k_end; k++) {
+ size_t k8 = k / 8, dk = k % 8;
+ __m256 a_bcast = _mm256_set1_ps(A[rn0 + k]);
+ __m256 b_val = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
+ acc00 = _mm256_fmadd_ps(a_bcast, b_val, acc00);
+ a_bcast = _mm256_set1_ps(A[rn1 + k]);
+ acc10 = _mm256_fmadd_ps(a_bcast, b_val, acc10);
+ a_bcast = _mm256_set1_ps(A[rn2 + k]);
+ acc20 = _mm256_fmadd_ps(a_bcast, b_val, acc20);
+ a_bcast = _mm256_set1_ps(A[rn3 + k]);
+ acc30 = _mm256_fmadd_ps(a_bcast, b_val, acc30);
+ }
+ }
+
+ acc00 = _mm256_add_ps(acc00, acc01);
+ acc02 = _mm256_add_ps(acc02, acc03);
+ acc00 = _mm256_add_ps(acc00, acc02);
+ acc10 = _mm256_add_ps(acc10, acc11);
+ acc12 = _mm256_add_ps(acc12, acc13);
+ acc10 = _mm256_add_ps(acc10, acc12);
+ acc20 = _mm256_add_ps(acc20, acc21);
+ acc22 = _mm256_add_ps(acc22, acc23);
+ acc20 = _mm256_add_ps(acc20, acc22);
+ acc30 = _mm256_add_ps(acc30, acc31);
+ acc32 = _mm256_add_ps(acc32, acc33);
+ acc30 = _mm256_add_ps(acc30, acc32);
+
+ float tmp[8] __attribute__((aligned(32)));
+ _mm256_store_ps(tmp, acc00);
+ for (size_t dj = 0; dj < 8; dj++) acc[li0 * tj + lj + dj] += tmp[dj];
+ _mm256_store_ps(tmp, acc10);
+ for (size_t dj = 0; dj < 8; dj++) acc[li1 * tj + lj + dj] += tmp[dj];
+ _mm256_store_ps(tmp, acc20);
+ for (size_t dj = 0; dj < 8; dj++) acc[li2 * tj + lj + dj] += tmp[dj];
+ _mm256_store_ps(tmp, acc30);
+ for (size_t dj = 0; dj < 8; dj++) acc[li3 * tj + lj + dj] += tmp[dj];
+ }
+
+ for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
+ size_t lj = j - jj;
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ for (size_t k = 0; k < n; k++) {
+ sum0 += (double)A[rn0 + k] * (double)B[k * p + j];
+ sum1 += (double)A[rn1 + k] * (double)B[k * p + j];
+ sum2 += (double)A[rn2 + k] * (double)B[k * p + j];
+ sum3 += (double)A[rn3 + k] * (double)B[k * p + j];
+ }
+ acc[li0 * tj + lj] += (float)sum0;
+ acc[li1 * tj + lj] += (float)sum1;
+ acc[li2 * tj + lj] += (float)sum2;
+ acc[li3 * tj + lj] += (float)sum3;
+ }
+ }
+
+ for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) {
+ size_t li = i - ii;
+ size_t rn = i * n;
+ for (size_t j = jj; j + 8 <= j_end; j += 8) {
+ size_t lj = j - jj;
+ size_t j8idx = j / 8;
+ __m256 acc0 = _mm256_setzero_ps();
+ __m256 acc1 = _mm256_setzero_ps();
+ __m256 acc2 = _mm256_setzero_ps();
+ __m256 acc3 = _mm256_setzero_ps();
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end8 = kk + (k_end - kk) / 8 * 8;
+ size_t k = kk;
for (; k + 4 <= k_end8; k += 4) {
size_t k8_0 = k / 8, dk_0 = k % 8;
size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8;
size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8;
size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8;
- __m256 a0 = _mm256_set1_ps(A[i * n + k]);
- __m256 a1 = _mm256_set1_ps(A[i * n + k + 1]);
- __m256 a2 = _mm256_set1_ps(A[i * n + k + 2]);
- __m256 a3 = _mm256_set1_ps(A[i * n + k + 3]);
+ __m256 a0 = _mm256_set1_ps(A[rn + k]);
+ __m256 a1 = _mm256_set1_ps(A[rn + k + 1]);
+ __m256 a2 = _mm256_set1_ps(A[rn + k + 2]);
+ __m256 a3 = _mm256_set1_ps(A[rn + k + 3]);
__m256 b0 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]);
__m256 b1 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]);
__m256 b2 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]);
@@ -361,24 +498,26 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const
acc2 = _mm256_fmadd_ps(a2, b2, acc2);
acc3 = _mm256_fmadd_ps(a3, b3, acc3);
}
- acc0 = _mm256_add_ps(acc0, acc1);
- acc2 = _mm256_add_ps(acc2, acc3);
- acc0 = _mm256_add_ps(acc0, acc2);
for (; k < k_end; k++) {
size_t k8 = k / 8, dk = k % 8;
- __m256 a_bcast = _mm256_set1_ps(A[i * n + k]);
+ __m256 a_bcast = _mm256_set1_ps(A[rn + k]);
__m256 b_val = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
acc0 = _mm256_fmadd_ps(a_bcast, b_val, acc0);
}
- float tmp[8] __attribute__((aligned(32)));
- _mm256_store_ps(tmp, acc0);
- for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj];
}
- for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
- size_t lj = j - jj;
- for (size_t k = kk; k < k_end; k++) {
- acc[li * tj + lj] += (double)A[i * n + k] * (double)B[k * p + j];
- }
+
+ acc0 = _mm256_add_ps(acc0, acc1);
+ acc2 = _mm256_add_ps(acc2, acc3);
+ acc0 = _mm256_add_ps(acc0, acc2);
+
+ float tmp[8] __attribute__((aligned(32)));
+ _mm256_store_ps(tmp, acc0);
+ for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj];
+ }
+ for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
+ size_t lj = j - jj;
+ for (size_t k = 0; k < n; k++) {
+ acc[li * tj + lj] += (double)A[rn + k] * (double)B[k * p + j];
}
}
}
@@ -386,10 +525,8 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const
for (size_t i = ii; i < i_end; i++) {
size_t li = i - ii;
for (size_t j = jj; j < j_end; j++) {
- size_t lj = j - jj;
- float v = acc[li * tj + lj];
- if (scale > 1.0) v /= (float)scale;
- C[i * p + j] = v;
+ size_t lj = j - jj;
+ C[i * p + j] = acc[li * tj + lj] * inv_scale;
}
}
}
@@ -401,8 +538,7 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const
for (size_t k = 0; k < n; k++) {
sum += (double)A[i * n + k] * (double)B[k * p + j];
}
- if (scale > 1.0) sum /= scale;
- C[i * p + j] = (float)sum;
+ C[i * p + j] = (float)(sum * inv_scale);
}
}
@@ -441,6 +577,8 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons
if (posix_memalign((void **)&B_packed, 64, p16 * n16 * 256 * sizeof(float)) != 0) return -1;
pack_b_f32_512(n, p, B, B_packed);
+ float inv_scale = (scale > 1.0) ? 1.0f / (float)scale : 1.0f;
+
#pragma omp parallel for schedule(static)
for (size_t ii = 0; ii < m; ii += ib) {
size_t i_end = (ii + ib < m) ? ii + ib : m;
@@ -451,28 +589,163 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons
float acc[64 * 64];
memset(acc, 0, ti * tj * sizeof(float));
- for (size_t kk = 0; kk < n; kk += kb) {
- size_t k_end = (kk + kb < n) ? kk + kb : n;
- size_t k_end16 = kk + (k_end - kk) / 16 * 16;
- for (size_t i = ii; i < i_end; i++) {
- size_t li = i - ii;
- for (size_t j = jj; j + 16 <= j_end; j += 16) {
- size_t lj = j - jj;
- size_t j16idx = j / 16;
- __m512 acc0 = _mm512_setzero_ps();
- __m512 acc1 = _mm512_setzero_ps();
- __m512 acc2 = _mm512_setzero_ps();
- __m512 acc3 = _mm512_setzero_ps();
- size_t k = kk;
+ for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) {
+ size_t li0 = i4 - ii;
+ size_t li1 = li0 + 1;
+ size_t li2 = li0 + 2;
+ size_t li3 = li0 + 3;
+ size_t rn0 = i4 * n;
+ size_t rn1 = (i4 + 1) * n;
+ size_t rn2 = (i4 + 2) * n;
+ size_t rn3 = (i4 + 3) * n;
+
+ for (size_t j = jj; j + 16 <= j_end; j += 16) {
+ size_t lj = j - jj;
+ size_t j16idx = j / 16;
+ __m512 acc00 = _mm512_setzero_ps();
+ __m512 acc01 = _mm512_setzero_ps();
+ __m512 acc02 = _mm512_setzero_ps();
+ __m512 acc03 = _mm512_setzero_ps();
+ __m512 acc10 = _mm512_setzero_ps();
+ __m512 acc11 = _mm512_setzero_ps();
+ __m512 acc12 = _mm512_setzero_ps();
+ __m512 acc13 = _mm512_setzero_ps();
+ __m512 acc20 = _mm512_setzero_ps();
+ __m512 acc21 = _mm512_setzero_ps();
+ __m512 acc22 = _mm512_setzero_ps();
+ __m512 acc23 = _mm512_setzero_ps();
+ __m512 acc30 = _mm512_setzero_ps();
+ __m512 acc31 = _mm512_setzero_ps();
+ __m512 acc32 = _mm512_setzero_ps();
+ __m512 acc33 = _mm512_setzero_ps();
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end16 = kk + (k_end - kk) / 16 * 16;
+ size_t k = kk;
+ for (; k + 4 <= k_end16; k += 4) {
+ size_t k16_0 = k / 16, dk_0 = k % 16;
+ size_t k16_1 = (k + 1) / 16, dk_1 = (k + 1) % 16;
+ size_t k16_2 = (k + 2) / 16, dk_2 = (k + 2) % 16;
+ size_t k16_3 = (k + 3) / 16, dk_3 = (k + 3) % 16;
+ __m512 a0 = _mm512_set1_ps(A[rn0 + k]);
+ __m512 a1 = _mm512_set1_ps(A[rn0 + k + 1]);
+ __m512 a2 = _mm512_set1_ps(A[rn0 + k + 2]);
+ __m512 a3 = _mm512_set1_ps(A[rn0 + k + 3]);
+ __m512 b0 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_0) * 256 + dk_0 * 16]);
+ __m512 b1 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_1) * 256 + dk_1 * 16]);
+ __m512 b2 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_2) * 256 + dk_2 * 16]);
+ __m512 b3 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_3) * 256 + dk_3 * 16]);
+ acc00 = _mm512_fmadd_ps(a0, b0, acc00);
+ acc01 = _mm512_fmadd_ps(a1, b1, acc01);
+ acc02 = _mm512_fmadd_ps(a2, b2, acc02);
+ acc03 = _mm512_fmadd_ps(a3, b3, acc03);
+ a0 = _mm512_set1_ps(A[rn1 + k]);
+ a1 = _mm512_set1_ps(A[rn1 + k + 1]);
+ a2 = _mm512_set1_ps(A[rn1 + k + 2]);
+ a3 = _mm512_set1_ps(A[rn1 + k + 3]);
+ acc10 = _mm512_fmadd_ps(a0, b0, acc10);
+ acc11 = _mm512_fmadd_ps(a1, b1, acc11);
+ acc12 = _mm512_fmadd_ps(a2, b2, acc12);
+ acc13 = _mm512_fmadd_ps(a3, b3, acc13);
+ a0 = _mm512_set1_ps(A[rn2 + k]);
+ a1 = _mm512_set1_ps(A[rn2 + k + 1]);
+ a2 = _mm512_set1_ps(A[rn2 + k + 2]);
+ a3 = _mm512_set1_ps(A[rn2 + k + 3]);
+ acc20 = _mm512_fmadd_ps(a0, b0, acc20);
+ acc21 = _mm512_fmadd_ps(a1, b1, acc21);
+ acc22 = _mm512_fmadd_ps(a2, b2, acc22);
+ acc23 = _mm512_fmadd_ps(a3, b3, acc23);
+ a0 = _mm512_set1_ps(A[rn3 + k]);
+ a1 = _mm512_set1_ps(A[rn3 + k + 1]);
+ a2 = _mm512_set1_ps(A[rn3 + k + 2]);
+ a3 = _mm512_set1_ps(A[rn3 + k + 3]);
+ acc30 = _mm512_fmadd_ps(a0, b0, acc30);
+ acc31 = _mm512_fmadd_ps(a1, b1, acc31);
+ acc32 = _mm512_fmadd_ps(a2, b2, acc32);
+ acc33 = _mm512_fmadd_ps(a3, b3, acc33);
+ }
+ for (; k < k_end; k++) {
+ size_t k16 = k / 16, dk = k % 16;
+ __m512 a_bcast = _mm512_set1_ps(A[rn0 + k]);
+ __m512 b_val = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]);
+ acc00 = _mm512_fmadd_ps(a_bcast, b_val, acc00);
+ a_bcast = _mm512_set1_ps(A[rn1 + k]);
+ acc10 = _mm512_fmadd_ps(a_bcast, b_val, acc10);
+ a_bcast = _mm512_set1_ps(A[rn2 + k]);
+ acc20 = _mm512_fmadd_ps(a_bcast, b_val, acc20);
+ a_bcast = _mm512_set1_ps(A[rn3 + k]);
+ acc30 = _mm512_fmadd_ps(a_bcast, b_val, acc30);
+ }
+ }
+
+ acc00 = _mm512_add_ps(acc00, acc01);
+ acc02 = _mm512_add_ps(acc02, acc03);
+ acc00 = _mm512_add_ps(acc00, acc02);
+ acc10 = _mm512_add_ps(acc10, acc11);
+ acc12 = _mm512_add_ps(acc12, acc13);
+ acc10 = _mm512_add_ps(acc10, acc12);
+ acc20 = _mm512_add_ps(acc20, acc21);
+ acc22 = _mm512_add_ps(acc22, acc23);
+ acc20 = _mm512_add_ps(acc20, acc22);
+ acc30 = _mm512_add_ps(acc30, acc31);
+ acc32 = _mm512_add_ps(acc32, acc33);
+ acc30 = _mm512_add_ps(acc30, acc32);
+
+ float tmp[16] __attribute__((aligned(64)));
+ _mm512_store_ps(tmp, acc00);
+ for (size_t dj = 0; dj < 16; dj++) acc[li0 * tj + lj + dj] += tmp[dj];
+ _mm512_store_ps(tmp, acc10);
+ for (size_t dj = 0; dj < 16; dj++) acc[li1 * tj + lj + dj] += tmp[dj];
+ _mm512_store_ps(tmp, acc20);
+ for (size_t dj = 0; dj < 16; dj++) acc[li2 * tj + lj + dj] += tmp[dj];
+ _mm512_store_ps(tmp, acc30);
+ for (size_t dj = 0; dj < 16; dj++) acc[li3 * tj + lj + dj] += tmp[dj];
+ }
+
+ for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) {
+ size_t lj = j - jj;
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ for (size_t k = 0; k < n; k++) {
+ sum0 += (double)A[rn0 + k] * (double)B[k * p + j];
+ sum1 += (double)A[rn1 + k] * (double)B[k * p + j];
+ sum2 += (double)A[rn2 + k] * (double)B[k * p + j];
+ sum3 += (double)A[rn3 + k] * (double)B[k * p + j];
+ }
+ acc[li0 * tj + lj] += (float)sum0;
+ acc[li1 * tj + lj] += (float)sum1;
+ acc[li2 * tj + lj] += (float)sum2;
+ acc[li3 * tj + lj] += (float)sum3;
+ }
+ }
+
+ for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) {
+ size_t li = i - ii;
+ size_t rn = i * n;
+ for (size_t j = jj; j + 16 <= j_end; j += 16) {
+ size_t lj = j - jj;
+ size_t j16idx = j / 16;
+ __m512 acc0 = _mm512_setzero_ps();
+ __m512 acc1 = _mm512_setzero_ps();
+ __m512 acc2 = _mm512_setzero_ps();
+ __m512 acc3 = _mm512_setzero_ps();
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end16 = kk + (k_end - kk) / 16 * 16;
+ size_t k = kk;
for (; k + 4 <= k_end16; k += 4) {
size_t k16_0 = k / 16, dk_0 = k % 16;
size_t k16_1 = (k + 1) / 16, dk_1 = (k + 1) % 16;
size_t k16_2 = (k + 2) / 16, dk_2 = (k + 2) % 16;
size_t k16_3 = (k + 3) / 16, dk_3 = (k + 3) % 16;
- __m512 a0 = _mm512_set1_ps(A[i * n + k]);
- __m512 a1 = _mm512_set1_ps(A[i * n + k + 1]);
- __m512 a2 = _mm512_set1_ps(A[i * n + k + 2]);
- __m512 a3 = _mm512_set1_ps(A[i * n + k + 3]);
+ __m512 a0 = _mm512_set1_ps(A[rn + k]);
+ __m512 a1 = _mm512_set1_ps(A[rn + k + 1]);
+ __m512 a2 = _mm512_set1_ps(A[rn + k + 2]);
+ __m512 a3 = _mm512_set1_ps(A[rn + k + 3]);
__m512 b0 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_0) * 256 + dk_0 * 16]);
__m512 b1 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_1) * 256 + dk_1 * 16]);
__m512 b2 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_2) * 256 + dk_2 * 16]);
@@ -482,24 +755,26 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons
acc2 = _mm512_fmadd_ps(a2, b2, acc2);
acc3 = _mm512_fmadd_ps(a3, b3, acc3);
}
- acc0 = _mm512_add_ps(acc0, acc1);
- acc2 = _mm512_add_ps(acc2, acc3);
- acc0 = _mm512_add_ps(acc0, acc2);
for (; k < k_end; k++) {
size_t k16 = k / 16, dk = k % 16;
- __m512 a_bcast = _mm512_set1_ps(A[i * n + k]);
+ __m512 a_bcast = _mm512_set1_ps(A[rn + k]);
__m512 b_val = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]);
acc0 = _mm512_fmadd_ps(a_bcast, b_val, acc0);
}
- float tmp[16] __attribute__((aligned(64)));
- _mm512_store_ps(tmp, acc0);
- for (size_t dj = 0; dj < 16; dj++) acc[li * tj + lj + dj] += tmp[dj];
}
- for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) {
- size_t lj = j - jj;
- for (size_t k = kk; k < k_end; k++) {
- acc[li * tj + lj] += (double)A[i * n + k] * (double)B[k * p + j];
- }
+
+ acc0 = _mm512_add_ps(acc0, acc1);
+ acc2 = _mm512_add_ps(acc2, acc3);
+ acc0 = _mm512_add_ps(acc0, acc2);
+
+ float tmp[16] __attribute__((aligned(64)));
+ _mm512_store_ps(tmp, acc0);
+ for (size_t dj = 0; dj < 16; dj++) acc[li * tj + lj + dj] += tmp[dj];
+ }
+ for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) {
+ size_t lj = j - jj;
+ for (size_t k = 0; k < n; k++) {
+ acc[li * tj + lj] += (double)A[rn + k] * (double)B[k * p + j];
}
}
}
@@ -507,10 +782,8 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons
for (size_t i = ii; i < i_end; i++) {
size_t li = i - ii;
for (size_t j = jj; j < j_end; j++) {
- size_t lj = j - jj;
- float v = acc[li * tj + lj];
- if (scale > 1.0) v /= (float)scale;
- C[i * p + j] = v;
+ size_t lj = j - jj;
+ C[i * p + j] = acc[li * tj + lj] * inv_scale;
}
}
}
@@ -522,8 +795,7 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons
for (size_t k = 0; k < n; k++) {
sum += (double)A[i * n + k] * (double)B[k * p + j];
}
- if (scale > 1.0) sum /= scale;
- C[i * p + j] = (float)sum;
+ C[i * p + j] = (float)(sum * inv_scale);
}
}
@@ -632,6 +904,8 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const
if (posix_memalign((void **)&B_packed, 64, p4 * n4 * 16 * sizeof(double)) != 0) return -1;
pack_b_f64(n, p, B, B_packed);
+ double inv_scale = (scale > 1.0) ? 1.0 / scale : 1.0;
+
#pragma omp parallel for schedule(static)
for (size_t ii = 0; ii < m; ii += ib) {
size_t i_end = (ii + ib < m) ? ii + ib : m;
@@ -642,28 +916,163 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const
double acc[64 * 64];
memset(acc, 0, ti * tj * sizeof(double));
- for (size_t kk = 0; kk < n; kk += kb) {
- size_t k_end = (kk + kb < n) ? kk + kb : n;
- size_t k_end4 = kk + (k_end - kk) / 4 * 4;
- for (size_t i = ii; i < i_end; i++) {
- size_t li = i - ii;
- for (size_t j = jj; j + 4 <= j_end; j += 4) {
- size_t lj = j - jj;
- size_t j4idx = j / 4;
- __m256d acc0 = _mm256_setzero_pd();
- __m256d acc1 = _mm256_setzero_pd();
- __m256d acc2 = _mm256_setzero_pd();
- __m256d acc3 = _mm256_setzero_pd();
- size_t k = kk;
+ for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) {
+ size_t li0 = i4 - ii;
+ size_t li1 = li0 + 1;
+ size_t li2 = li0 + 2;
+ size_t li3 = li0 + 3;
+ size_t rn0 = i4 * n;
+ size_t rn1 = (i4 + 1) * n;
+ size_t rn2 = (i4 + 2) * n;
+ size_t rn3 = (i4 + 3) * n;
+
+ for (size_t j = jj; j + 4 <= j_end; j += 4) {
+ size_t lj = j - jj;
+ size_t j4idx = j / 4;
+ __m256d acc00 = _mm256_setzero_pd();
+ __m256d acc01 = _mm256_setzero_pd();
+ __m256d acc02 = _mm256_setzero_pd();
+ __m256d acc03 = _mm256_setzero_pd();
+ __m256d acc10 = _mm256_setzero_pd();
+ __m256d acc11 = _mm256_setzero_pd();
+ __m256d acc12 = _mm256_setzero_pd();
+ __m256d acc13 = _mm256_setzero_pd();
+ __m256d acc20 = _mm256_setzero_pd();
+ __m256d acc21 = _mm256_setzero_pd();
+ __m256d acc22 = _mm256_setzero_pd();
+ __m256d acc23 = _mm256_setzero_pd();
+ __m256d acc30 = _mm256_setzero_pd();
+ __m256d acc31 = _mm256_setzero_pd();
+ __m256d acc32 = _mm256_setzero_pd();
+ __m256d acc33 = _mm256_setzero_pd();
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end4 = kk + (k_end - kk) / 4 * 4;
+ size_t k = kk;
+ for (; k + 4 <= k_end4; k += 4) {
+ size_t k4_0 = k / 4, dk_0 = k % 4;
+ size_t k4_1 = (k + 1) / 4, dk_1 = (k + 1) % 4;
+ size_t k4_2 = (k + 2) / 4, dk_2 = (k + 2) % 4;
+ size_t k4_3 = (k + 3) / 4, dk_3 = (k + 3) % 4;
+ __m256d a0 = _mm256_set1_pd(A[rn0 + k]);
+ __m256d a1 = _mm256_set1_pd(A[rn0 + k + 1]);
+ __m256d a2 = _mm256_set1_pd(A[rn0 + k + 2]);
+ __m256d a3 = _mm256_set1_pd(A[rn0 + k + 3]);
+ __m256d b0 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_0) * 16 + dk_0 * 4]);
+ __m256d b1 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_1) * 16 + dk_1 * 4]);
+ __m256d b2 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_2) * 16 + dk_2 * 4]);
+ __m256d b3 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_3) * 16 + dk_3 * 4]);
+ acc00 = _mm256_fmadd_pd(a0, b0, acc00);
+ acc01 = _mm256_fmadd_pd(a1, b1, acc01);
+ acc02 = _mm256_fmadd_pd(a2, b2, acc02);
+ acc03 = _mm256_fmadd_pd(a3, b3, acc03);
+ a0 = _mm256_set1_pd(A[rn1 + k]);
+ a1 = _mm256_set1_pd(A[rn1 + k + 1]);
+ a2 = _mm256_set1_pd(A[rn1 + k + 2]);
+ a3 = _mm256_set1_pd(A[rn1 + k + 3]);
+ acc10 = _mm256_fmadd_pd(a0, b0, acc10);
+ acc11 = _mm256_fmadd_pd(a1, b1, acc11);
+ acc12 = _mm256_fmadd_pd(a2, b2, acc12);
+ acc13 = _mm256_fmadd_pd(a3, b3, acc13);
+ a0 = _mm256_set1_pd(A[rn2 + k]);
+ a1 = _mm256_set1_pd(A[rn2 + k + 1]);
+ a2 = _mm256_set1_pd(A[rn2 + k + 2]);
+ a3 = _mm256_set1_pd(A[rn2 + k + 3]);
+ acc20 = _mm256_fmadd_pd(a0, b0, acc20);
+ acc21 = _mm256_fmadd_pd(a1, b1, acc21);
+ acc22 = _mm256_fmadd_pd(a2, b2, acc22);
+ acc23 = _mm256_fmadd_pd(a3, b3, acc23);
+ a0 = _mm256_set1_pd(A[rn3 + k]);
+ a1 = _mm256_set1_pd(A[rn3 + k + 1]);
+ a2 = _mm256_set1_pd(A[rn3 + k + 2]);
+ a3 = _mm256_set1_pd(A[rn3 + k + 3]);
+ acc30 = _mm256_fmadd_pd(a0, b0, acc30);
+ acc31 = _mm256_fmadd_pd(a1, b1, acc31);
+ acc32 = _mm256_fmadd_pd(a2, b2, acc32);
+ acc33 = _mm256_fmadd_pd(a3, b3, acc33);
+ }
+ for (; k < k_end; k++) {
+ size_t k4 = k / 4, dk = k % 4;
+ __m256d a_bcast = _mm256_set1_pd(A[rn0 + k]);
+ __m256d b_val = _mm256_load_pd(&B_packed[(j4idx * n4 + k4) * 16 + dk * 4]);
+ acc00 = _mm256_fmadd_pd(a_bcast, b_val, acc00);
+ a_bcast = _mm256_set1_pd(A[rn1 + k]);
+ acc10 = _mm256_fmadd_pd(a_bcast, b_val, acc10);
+ a_bcast = _mm256_set1_pd(A[rn2 + k]);
+ acc20 = _mm256_fmadd_pd(a_bcast, b_val, acc20);
+ a_bcast = _mm256_set1_pd(A[rn3 + k]);
+ acc30 = _mm256_fmadd_pd(a_bcast, b_val, acc30);
+ }
+ }
+
+ acc00 = _mm256_add_pd(acc00, acc01);
+ acc02 = _mm256_add_pd(acc02, acc03);
+ acc00 = _mm256_add_pd(acc00, acc02);
+ acc10 = _mm256_add_pd(acc10, acc11);
+ acc12 = _mm256_add_pd(acc12, acc13);
+ acc10 = _mm256_add_pd(acc10, acc12);
+ acc20 = _mm256_add_pd(acc20, acc21);
+ acc22 = _mm256_add_pd(acc22, acc23);
+ acc20 = _mm256_add_pd(acc20, acc22);
+ acc30 = _mm256_add_pd(acc30, acc31);
+ acc32 = _mm256_add_pd(acc32, acc33);
+ acc30 = _mm256_add_pd(acc30, acc32);
+
+ double tmp[4] __attribute__((aligned(32)));
+ _mm256_store_pd(tmp, acc00);
+ for (size_t dj = 0; dj < 4; dj++) acc[li0 * tj + lj + dj] += tmp[dj];
+ _mm256_store_pd(tmp, acc10);
+ for (size_t dj = 0; dj < 4; dj++) acc[li1 * tj + lj + dj] += tmp[dj];
+ _mm256_store_pd(tmp, acc20);
+ for (size_t dj = 0; dj < 4; dj++) acc[li2 * tj + lj + dj] += tmp[dj];
+ _mm256_store_pd(tmp, acc30);
+ for (size_t dj = 0; dj < 4; dj++) acc[li3 * tj + lj + dj] += tmp[dj];
+ }
+
+ for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) {
+ size_t lj = j - jj;
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ for (size_t k = 0; k < n; k++) {
+ sum0 += A[rn0 + k] * B[k * p + j];
+ sum1 += A[rn1 + k] * B[k * p + j];
+ sum2 += A[rn2 + k] * B[k * p + j];
+ sum3 += A[rn3 + k] * B[k * p + j];
+ }
+ acc[li0 * tj + lj] += sum0;
+ acc[li1 * tj + lj] += sum1;
+ acc[li2 * tj + lj] += sum2;
+ acc[li3 * tj + lj] += sum3;
+ }
+ }
+
+ for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) {
+ size_t li = i - ii;
+ size_t rn = i * n;
+ for (size_t j = jj; j + 4 <= j_end; j += 4) {
+ size_t lj = j - jj;
+ size_t j4idx = j / 4;
+ __m256d acc0 = _mm256_setzero_pd();
+ __m256d acc1 = _mm256_setzero_pd();
+ __m256d acc2 = _mm256_setzero_pd();
+ __m256d acc3 = _mm256_setzero_pd();
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end4 = kk + (k_end - kk) / 4 * 4;
+ size_t k = kk;
for (; k + 4 <= k_end4; k += 4) {
size_t k4_0 = k / 4, dk_0 = k % 4;
size_t k4_1 = (k + 1) / 4, dk_1 = (k + 1) % 4;
size_t k4_2 = (k + 2) / 4, dk_2 = (k + 2) % 4;
size_t k4_3 = (k + 3) / 4, dk_3 = (k + 3) % 4;
- __m256d a0 = _mm256_set1_pd(A[i * n + k]);
- __m256d a1 = _mm256_set1_pd(A[i * n + k + 1]);
- __m256d a2 = _mm256_set1_pd(A[i * n + k + 2]);
- __m256d a3 = _mm256_set1_pd(A[i * n + k + 3]);
+ __m256d a0 = _mm256_set1_pd(A[rn + k]);
+ __m256d a1 = _mm256_set1_pd(A[rn + k + 1]);
+ __m256d a2 = _mm256_set1_pd(A[rn + k + 2]);
+ __m256d a3 = _mm256_set1_pd(A[rn + k + 3]);
__m256d b0 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_0) * 16 + dk_0 * 4]);
__m256d b1 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_1) * 16 + dk_1 * 4]);
__m256d b2 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_2) * 16 + dk_2 * 4]);
@@ -673,24 +1082,26 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const
acc2 = _mm256_fmadd_pd(a2, b2, acc2);
acc3 = _mm256_fmadd_pd(a3, b3, acc3);
}
- acc0 = _mm256_add_pd(acc0, acc1);
- acc2 = _mm256_add_pd(acc2, acc3);
- acc0 = _mm256_add_pd(acc0, acc2);
for (; k < k_end; k++) {
size_t k4 = k / 4, dk = k % 4;
- __m256d a_bcast = _mm256_set1_pd(A[i * n + k]);
+ __m256d a_bcast = _mm256_set1_pd(A[rn + k]);
__m256d b_val = _mm256_load_pd(&B_packed[(j4idx * n4 + k4) * 16 + dk * 4]);
acc0 = _mm256_fmadd_pd(a_bcast, b_val, acc0);
}
- double tmp[4] __attribute__((aligned(32)));
- _mm256_store_pd(tmp, acc0);
- for (size_t dj = 0; dj < 4; dj++) acc[li * tj + lj + dj] += tmp[dj];
}
- for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) {
- size_t lj = j - jj;
- for (size_t k = kk; k < k_end; k++) {
- acc[li * tj + lj] += A[i * n + k] * B[k * p + j];
- }
+
+ acc0 = _mm256_add_pd(acc0, acc1);
+ acc2 = _mm256_add_pd(acc2, acc3);
+ acc0 = _mm256_add_pd(acc0, acc2);
+
+ double tmp[4] __attribute__((aligned(32)));
+ _mm256_store_pd(tmp, acc0);
+ for (size_t dj = 0; dj < 4; dj++) acc[li * tj + lj + dj] += tmp[dj];
+ }
+ for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) {
+ size_t lj = j - jj;
+ for (size_t k = 0; k < n; k++) {
+ acc[li * tj + lj] += A[rn + k] * B[k * p + j];
}
}
}
@@ -698,10 +1109,8 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const
for (size_t i = ii; i < i_end; i++) {
size_t li = i - ii;
for (size_t j = jj; j < j_end; j++) {
- size_t lj = j - jj;
- double v = acc[li * tj + lj];
- if (scale > 1.0) v /= scale;
- C[i * p + j] = v;
+ size_t lj = j - jj;
+ C[i * p + j] = acc[li * tj + lj] * inv_scale;
}
}
}
@@ -713,8 +1122,7 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const
for (size_t k = 0; k < n; k++) {
sum += A[i * n + k] * B[k * p + j];
}
- if (scale > 1.0) sum /= scale;
- C[i * p + j] = sum;
+ C[i * p + j] = sum * inv_scale;
}
}
@@ -753,6 +1161,8 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con
if (posix_memalign((void **)&B_packed, 64, p8 * n8 * 64 * sizeof(double)) != 0) return -1;
pack_b_f64_512(n, p, B, B_packed);
+ double inv_scale = (scale > 1.0) ? 1.0 / scale : 1.0;
+
#pragma omp parallel for schedule(static)
for (size_t ii = 0; ii < m; ii += ib) {
size_t i_end = (ii + ib < m) ? ii + ib : m;
@@ -763,28 +1173,163 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con
double acc[64 * 64];
memset(acc, 0, ti * tj * sizeof(double));
- for (size_t kk = 0; kk < n; kk += kb) {
- size_t k_end = (kk + kb < n) ? kk + kb : n;
- size_t k_end8 = kk + (k_end - kk) / 8 * 8;
- for (size_t i = ii; i < i_end; i++) {
- size_t li = i - ii;
- for (size_t j = jj; j + 8 <= j_end; j += 8) {
- size_t lj = j - jj;
- size_t j8idx = j / 8;
- __m512d acc0 = _mm512_setzero_pd();
- __m512d acc1 = _mm512_setzero_pd();
- __m512d acc2 = _mm512_setzero_pd();
- __m512d acc3 = _mm512_setzero_pd();
- size_t k = kk;
+ for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) {
+ size_t li0 = i4 - ii;
+ size_t li1 = li0 + 1;
+ size_t li2 = li0 + 2;
+ size_t li3 = li0 + 3;
+ size_t rn0 = i4 * n;
+ size_t rn1 = (i4 + 1) * n;
+ size_t rn2 = (i4 + 2) * n;
+ size_t rn3 = (i4 + 3) * n;
+
+ for (size_t j = jj; j + 8 <= j_end; j += 8) {
+ size_t lj = j - jj;
+ size_t j8idx = j / 8;
+ __m512d acc00 = _mm512_setzero_pd();
+ __m512d acc01 = _mm512_setzero_pd();
+ __m512d acc02 = _mm512_setzero_pd();
+ __m512d acc03 = _mm512_setzero_pd();
+ __m512d acc10 = _mm512_setzero_pd();
+ __m512d acc11 = _mm512_setzero_pd();
+ __m512d acc12 = _mm512_setzero_pd();
+ __m512d acc13 = _mm512_setzero_pd();
+ __m512d acc20 = _mm512_setzero_pd();
+ __m512d acc21 = _mm512_setzero_pd();
+ __m512d acc22 = _mm512_setzero_pd();
+ __m512d acc23 = _mm512_setzero_pd();
+ __m512d acc30 = _mm512_setzero_pd();
+ __m512d acc31 = _mm512_setzero_pd();
+ __m512d acc32 = _mm512_setzero_pd();
+ __m512d acc33 = _mm512_setzero_pd();
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end8 = kk + (k_end - kk) / 8 * 8;
+ size_t k = kk;
for (; k + 4 <= k_end8; k += 4) {
size_t k8_0 = k / 8, dk_0 = k % 8;
size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8;
size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8;
size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8;
- __m512d a0 = _mm512_set1_pd(A[i * n + k]);
- __m512d a1 = _mm512_set1_pd(A[i * n + k + 1]);
- __m512d a2 = _mm512_set1_pd(A[i * n + k + 2]);
- __m512d a3 = _mm512_set1_pd(A[i * n + k + 3]);
+ __m512d a0 = _mm512_set1_pd(A[rn0 + k]);
+ __m512d a1 = _mm512_set1_pd(A[rn0 + k + 1]);
+ __m512d a2 = _mm512_set1_pd(A[rn0 + k + 2]);
+ __m512d a3 = _mm512_set1_pd(A[rn0 + k + 3]);
+ __m512d b0 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]);
+ __m512d b1 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]);
+ __m512d b2 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]);
+ __m512d b3 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]);
+ acc00 = _mm512_fmadd_pd(a0, b0, acc00);
+ acc01 = _mm512_fmadd_pd(a1, b1, acc01);
+ acc02 = _mm512_fmadd_pd(a2, b2, acc02);
+ acc03 = _mm512_fmadd_pd(a3, b3, acc03);
+ a0 = _mm512_set1_pd(A[rn1 + k]);
+ a1 = _mm512_set1_pd(A[rn1 + k + 1]);
+ a2 = _mm512_set1_pd(A[rn1 + k + 2]);
+ a3 = _mm512_set1_pd(A[rn1 + k + 3]);
+ acc10 = _mm512_fmadd_pd(a0, b0, acc10);
+ acc11 = _mm512_fmadd_pd(a1, b1, acc11);
+ acc12 = _mm512_fmadd_pd(a2, b2, acc12);
+ acc13 = _mm512_fmadd_pd(a3, b3, acc13);
+ a0 = _mm512_set1_pd(A[rn2 + k]);
+ a1 = _mm512_set1_pd(A[rn2 + k + 1]);
+ a2 = _mm512_set1_pd(A[rn2 + k + 2]);
+ a3 = _mm512_set1_pd(A[rn2 + k + 3]);
+ acc20 = _mm512_fmadd_pd(a0, b0, acc20);
+ acc21 = _mm512_fmadd_pd(a1, b1, acc21);
+ acc22 = _mm512_fmadd_pd(a2, b2, acc22);
+ acc23 = _mm512_fmadd_pd(a3, b3, acc23);
+ a0 = _mm512_set1_pd(A[rn3 + k]);
+ a1 = _mm512_set1_pd(A[rn3 + k + 1]);
+ a2 = _mm512_set1_pd(A[rn3 + k + 2]);
+ a3 = _mm512_set1_pd(A[rn3 + k + 3]);
+ acc30 = _mm512_fmadd_pd(a0, b0, acc30);
+ acc31 = _mm512_fmadd_pd(a1, b1, acc31);
+ acc32 = _mm512_fmadd_pd(a2, b2, acc32);
+ acc33 = _mm512_fmadd_pd(a3, b3, acc33);
+ }
+ for (; k < k_end; k++) {
+ size_t k8 = k / 8, dk = k % 8;
+ __m512d a_bcast = _mm512_set1_pd(A[rn0 + k]);
+ __m512d b_val = _mm512_load_pd(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
+ acc00 = _mm512_fmadd_pd(a_bcast, b_val, acc00);
+ a_bcast = _mm512_set1_pd(A[rn1 + k]);
+ acc10 = _mm512_fmadd_pd(a_bcast, b_val, acc10);
+ a_bcast = _mm512_set1_pd(A[rn2 + k]);
+ acc20 = _mm512_fmadd_pd(a_bcast, b_val, acc20);
+ a_bcast = _mm512_set1_pd(A[rn3 + k]);
+ acc30 = _mm512_fmadd_pd(a_bcast, b_val, acc30);
+ }
+ }
+
+ acc00 = _mm512_add_pd(acc00, acc01);
+ acc02 = _mm512_add_pd(acc02, acc03);
+ acc00 = _mm512_add_pd(acc00, acc02);
+ acc10 = _mm512_add_pd(acc10, acc11);
+ acc12 = _mm512_add_pd(acc12, acc13);
+ acc10 = _mm512_add_pd(acc10, acc12);
+ acc20 = _mm512_add_pd(acc20, acc21);
+ acc22 = _mm512_add_pd(acc22, acc23);
+ acc20 = _mm512_add_pd(acc20, acc22);
+ acc30 = _mm512_add_pd(acc30, acc31);
+ acc32 = _mm512_add_pd(acc32, acc33);
+ acc30 = _mm512_add_pd(acc30, acc32);
+
+ double tmp[8] __attribute__((aligned(64)));
+ _mm512_store_pd(tmp, acc00);
+ for (size_t dj = 0; dj < 8; dj++) acc[li0 * tj + lj + dj] += tmp[dj];
+ _mm512_store_pd(tmp, acc10);
+ for (size_t dj = 0; dj < 8; dj++) acc[li1 * tj + lj + dj] += tmp[dj];
+ _mm512_store_pd(tmp, acc20);
+ for (size_t dj = 0; dj < 8; dj++) acc[li2 * tj + lj + dj] += tmp[dj];
+ _mm512_store_pd(tmp, acc30);
+ for (size_t dj = 0; dj < 8; dj++) acc[li3 * tj + lj + dj] += tmp[dj];
+ }
+
+ for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
+ size_t lj = j - jj;
+ double sum0 = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+ double sum3 = 0.0;
+ for (size_t k = 0; k < n; k++) {
+ sum0 += A[rn0 + k] * B[k * p + j];
+ sum1 += A[rn1 + k] * B[k * p + j];
+ sum2 += A[rn2 + k] * B[k * p + j];
+ sum3 += A[rn3 + k] * B[k * p + j];
+ }
+ acc[li0 * tj + lj] += sum0;
+ acc[li1 * tj + lj] += sum1;
+ acc[li2 * tj + lj] += sum2;
+ acc[li3 * tj + lj] += sum3;
+ }
+ }
+
+ for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) {
+ size_t li = i - ii;
+ size_t rn = i * n;
+ for (size_t j = jj; j + 8 <= j_end; j += 8) {
+ size_t lj = j - jj;
+ size_t j8idx = j / 8;
+ __m512d acc0 = _mm512_setzero_pd();
+ __m512d acc1 = _mm512_setzero_pd();
+ __m512d acc2 = _mm512_setzero_pd();
+ __m512d acc3 = _mm512_setzero_pd();
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ size_t k_end8 = kk + (k_end - kk) / 8 * 8;
+ size_t k = kk;
+ for (; k + 4 <= k_end8; k += 4) {
+ size_t k8_0 = k / 8, dk_0 = k % 8;
+ size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8;
+ size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8;
+ size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8;
+ __m512d a0 = _mm512_set1_pd(A[rn + k]);
+ __m512d a1 = _mm512_set1_pd(A[rn + k + 1]);
+ __m512d a2 = _mm512_set1_pd(A[rn + k + 2]);
+ __m512d a3 = _mm512_set1_pd(A[rn + k + 3]);
__m512d b0 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]);
__m512d b1 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]);
__m512d b2 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]);
@@ -794,24 +1339,26 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con
acc2 = _mm512_fmadd_pd(a2, b2, acc2);
acc3 = _mm512_fmadd_pd(a3, b3, acc3);
}
- acc0 = _mm512_add_pd(acc0, acc1);
- acc2 = _mm512_add_pd(acc2, acc3);
- acc0 = _mm512_add_pd(acc0, acc2);
for (; k < k_end; k++) {
size_t k8 = k / 8, dk = k % 8;
- __m512d a_bcast = _mm512_set1_pd(A[i * n + k]);
+ __m512d a_bcast = _mm512_set1_pd(A[rn + k]);
__m512d b_val = _mm512_load_pd(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
acc0 = _mm512_fmadd_pd(a_bcast, b_val, acc0);
}
- double tmp[8] __attribute__((aligned(64)));
- _mm512_store_pd(tmp, acc0);
- for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj];
}
- for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
- size_t lj = j - jj;
- for (size_t k = kk; k < k_end; k++) {
- acc[li * tj + lj] += A[i * n + k] * B[k * p + j];
- }
+
+ acc0 = _mm512_add_pd(acc0, acc1);
+ acc2 = _mm512_add_pd(acc2, acc3);
+ acc0 = _mm512_add_pd(acc0, acc2);
+
+ double tmp[8] __attribute__((aligned(64)));
+ _mm512_store_pd(tmp, acc0);
+ for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj];
+ }
+ for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
+ size_t lj = j - jj;
+ for (size_t k = 0; k < n; k++) {
+ acc[li * tj + lj] += A[rn + k] * B[k * p + j];
}
}
}
@@ -819,10 +1366,8 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con
for (size_t i = ii; i < i_end; i++) {
size_t li = i - ii;
for (size_t j = jj; j < j_end; j++) {
- size_t lj = j - jj;
- double v = acc[li * tj + lj];
- if (scale > 1.0) v /= scale;
- C[i * p + j] = v;
+ size_t lj = j - jj;
+ C[i * p + j] = acc[li * tj + lj] * inv_scale;
}
}
}
@@ -834,8 +1379,7 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con
for (size_t k = 0; k < n; k++) {
sum += A[i * n + k] * B[k * p + j];
}
- if (scale > 1.0) sum /= scale;
- C[i * p + j] = sum;
+ C[i * p + j] = sum * inv_scale;
}
}