commit afc408a3afb47a47d58034b8f0ce806b5ad47c0a
parent 3498a66cb3f602d40f3e7091207cb7c0d6a9648b
Author: finwo <finwo@pm.me>
Date: Mon, 20 Apr 2026 02:07:11 +0200
float performance increase
Diffstat:
| M | src/matmul.c | | | 256 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------ |
1 file changed, 200 insertions(+), 56 deletions(-)
diff --git a/src/matmul.c b/src/matmul.c
@@ -309,6 +309,10 @@ static void pack_b_f32(size_t n, size_t p, const float *B, float *B_packed) {
}
int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
+ const size_t ib = 64;
+ const size_t jb = 64;
+ const size_t kb = 16;
+
size_t n8 = n / 8;
size_t p8 = p / 8;
@@ -317,24 +321,56 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const
pack_b_f32(n, p, B, B_packed);
#pragma omp parallel for schedule(static)
- for (size_t i = 0; i < m; i++) {
- for (size_t j8 = 0; j8 < p8; j8++) {
- __m256 result = _mm256_setzero_ps();
- for (size_t k8 = 0; k8 < n8; k8++) {
- for (size_t dk = 0; dk < 8; dk++) {
- __m256 b_val = _mm256_load_ps(&B_packed[(j8 * n8 + k8) * 64 + dk * 8]);
- __m256 a_bcast = _mm256_set1_ps(A[i * n + k8 * 8 + dk]);
- result = _mm256_fmadd_ps(a_bcast, b_val, result);
+ for (size_t ii = 0; ii < m; ii += ib) {
+ size_t i_end = (ii + ib < m) ? ii + ib : m;
+ for (size_t jj = 0; jj < p8 * 8; jj += jb) {
+ size_t j_end = (jj + jb < p8 * 8) ? jj + jb : p8 * 8;
+ size_t ti = i_end - ii;
+ size_t tj = j_end - jj;
+ float acc[64 * 64];
+ memset(acc, 0, ti * tj * sizeof(float));
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j + 8 <= j_end; j += 8) {
+ size_t lj = j - jj;
+ size_t j8idx = j / 8;
+ __m256 acc_vec = _mm256_setzero_ps();
+ for (size_t k = kk; k < k_end; k++) {
+ size_t k8 = k / 8;
+ size_t dk = k % 8;
+ __m256 a_bcast = _mm256_set1_ps(A[i * n + k]);
+ __m256 b_val = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
+ acc_vec = _mm256_fmadd_ps(a_bcast, b_val, acc_vec);
+ }
+ float tmp[8] __attribute__((aligned(32)));
+ _mm256_store_ps(tmp, acc_vec);
+ for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj];
+ }
+ for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
+ size_t lj = j - jj;
+ for (size_t k = kk; k < k_end; k++) {
+ acc[li * tj + lj] += (double)A[i * n + k] * (double)B[k * p + j];
+ }
+ }
}
}
- float tmp[8] __attribute__((aligned(32)));
- _mm256_store_ps(tmp, result);
- for (size_t dj = 0; dj < 8; dj++) {
- float v = tmp[dj];
- if (scale > 1.0) v /= (float)scale;
- C[i * p + j8 * 8 + dj] = v;
+
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j < j_end; j++) {
+ size_t lj = j - jj;
+ float v = acc[li * tj + lj];
+ if (scale > 1.0) v /= (float)scale;
+ C[i * p + j] = v;
+ }
}
}
+ }
+
+ for (size_t i = 0; i < m; i++) {
for (size_t j = p8 * 8; j < p; j++) {
double sum = 0.0;
for (size_t k = 0; k < n; k++) {
@@ -369,6 +405,10 @@ static void pack_b_f32_512(size_t n, size_t p, const float *B, float *B_packed)
}
int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
+ const size_t ib = 64;
+ const size_t jb = 64;
+ const size_t kb = 16;
+
size_t n16 = n / 16;
size_t p16 = p / 16;
@@ -377,24 +417,56 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons
pack_b_f32_512(n, p, B, B_packed);
#pragma omp parallel for schedule(static)
- for (size_t i = 0; i < m; i++) {
- for (size_t j16 = 0; j16 < p16; j16++) {
- __m512 result = _mm512_setzero_ps();
- for (size_t k16 = 0; k16 < n16; k16++) {
- for (size_t dk = 0; dk < 16; dk++) {
- __m512 b_val = _mm512_load_ps(&B_packed[(j16 * n16 + k16) * 256 + dk * 16]);
- __m512 a_bcast = _mm512_set1_ps(A[i * n + k16 * 16 + dk]);
- result = _mm512_fmadd_ps(a_bcast, b_val, result);
+ for (size_t ii = 0; ii < m; ii += ib) {
+ size_t i_end = (ii + ib < m) ? ii + ib : m;
+ for (size_t jj = 0; jj < p16 * 16; jj += jb) {
+ size_t j_end = (jj + jb < p16 * 16) ? jj + jb : p16 * 16;
+ size_t ti = i_end - ii;
+ size_t tj = j_end - jj;
+ float acc[64 * 64];
+ memset(acc, 0, ti * tj * sizeof(float));
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j + 16 <= j_end; j += 16) {
+ size_t lj = j - jj;
+ size_t j16idx = j / 16;
+ __m512 acc_vec = _mm512_setzero_ps();
+ for (size_t k = kk; k < k_end; k++) {
+ size_t k16 = k / 16;
+ size_t dk = k % 16;
+ __m512 a_bcast = _mm512_set1_ps(A[i * n + k]);
+ __m512 b_val = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]);
+ acc_vec = _mm512_fmadd_ps(a_bcast, b_val, acc_vec);
+ }
+ float tmp[16] __attribute__((aligned(64)));
+ _mm512_store_ps(tmp, acc_vec);
+ for (size_t dj = 0; dj < 16; dj++) acc[li * tj + lj + dj] += tmp[dj];
+ }
+ for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) {
+ size_t lj = j - jj;
+ for (size_t k = kk; k < k_end; k++) {
+ acc[li * tj + lj] += (double)A[i * n + k] * (double)B[k * p + j];
+ }
+ }
}
}
- float tmp[16] __attribute__((aligned(64)));
- _mm512_store_ps(tmp, result);
- for (size_t dj = 0; dj < 16; dj++) {
- float v = tmp[dj];
- if (scale > 1.0) v /= (float)scale;
- C[i * p + j16 * 16 + dj] = v;
+
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j < j_end; j++) {
+ size_t lj = j - jj;
+ float v = acc[li * tj + lj];
+ if (scale > 1.0) v /= (float)scale;
+ C[i * p + j] = v;
+ }
}
}
+ }
+
+ for (size_t i = 0; i < m; i++) {
for (size_t j = p16 * 16; j < p; j++) {
double sum = 0.0;
for (size_t k = 0; k < n; k++) {
@@ -499,6 +571,10 @@ static void pack_b_f64(size_t n, size_t p, const double *B, double *B_packed) {
}
int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
+ const size_t ib = 64;
+ const size_t jb = 64;
+ const size_t kb = 8;
+
size_t n4 = n / 4;
size_t p4 = p / 4;
@@ -507,24 +583,56 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const
pack_b_f64(n, p, B, B_packed);
#pragma omp parallel for schedule(static)
- for (size_t i = 0; i < m; i++) {
- for (size_t j4 = 0; j4 < p4; j4++) {
- __m256d result = _mm256_setzero_pd();
- for (size_t k4 = 0; k4 < n4; k4++) {
- for (size_t dk = 0; dk < 4; dk++) {
- __m256d b_val = _mm256_load_pd(&B_packed[(j4 * n4 + k4) * 16 + dk * 4]);
- __m256d a_bcast = _mm256_set1_pd(A[i * n + k4 * 4 + dk]);
- result = _mm256_fmadd_pd(a_bcast, b_val, result);
+ for (size_t ii = 0; ii < m; ii += ib) {
+ size_t i_end = (ii + ib < m) ? ii + ib : m;
+ for (size_t jj = 0; jj < p4 * 4; jj += jb) {
+ size_t j_end = (jj + jb < p4 * 4) ? jj + jb : p4 * 4;
+ size_t ti = i_end - ii;
+ size_t tj = j_end - jj;
+ double acc[64 * 64];
+ memset(acc, 0, ti * tj * sizeof(double));
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j + 4 <= j_end; j += 4) {
+ size_t lj = j - jj;
+ size_t j4idx = j / 4;
+ __m256d acc_vec = _mm256_setzero_pd();
+ for (size_t k = kk; k < k_end; k++) {
+ size_t k4 = k / 4;
+ size_t dk = k % 4;
+ __m256d a_bcast = _mm256_set1_pd(A[i * n + k]);
+ __m256d b_val = _mm256_load_pd(&B_packed[(j4idx * n4 + k4) * 16 + dk * 4]);
+ acc_vec = _mm256_fmadd_pd(a_bcast, b_val, acc_vec);
+ }
+ double tmp[4] __attribute__((aligned(32)));
+ _mm256_store_pd(tmp, acc_vec);
+ for (size_t dj = 0; dj < 4; dj++) acc[li * tj + lj + dj] += tmp[dj];
+ }
+ for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) {
+ size_t lj = j - jj;
+ for (size_t k = kk; k < k_end; k++) {
+ acc[li * tj + lj] += A[i * n + k] * B[k * p + j];
+ }
+ }
}
}
- double tmp[4] __attribute__((aligned(32)));
- _mm256_store_pd(tmp, result);
- for (size_t dj = 0; dj < 4; dj++) {
- double v = tmp[dj];
- if (scale > 1.0) v /= scale;
- C[i * p + j4 * 4 + dj] = v;
+
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j < j_end; j++) {
+ size_t lj = j - jj;
+ double v = acc[li * tj + lj];
+ if (scale > 1.0) v /= scale;
+ C[i * p + j] = v;
+ }
}
}
+ }
+
+ for (size_t i = 0; i < m; i++) {
for (size_t j = p4 * 4; j < p; j++) {
double sum = 0.0;
for (size_t k = 0; k < n; k++) {
@@ -559,6 +667,10 @@ static void pack_b_f64_512(size_t n, size_t p, const double *B, double *B_packed
}
int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
+ const size_t ib = 64;
+ const size_t jb = 64;
+ const size_t kb = 8;
+
size_t n8 = n / 8;
size_t p8 = p / 8;
@@ -567,24 +679,56 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con
pack_b_f64_512(n, p, B, B_packed);
#pragma omp parallel for schedule(static)
- for (size_t i = 0; i < m; i++) {
- for (size_t j8 = 0; j8 < p8; j8++) {
- __m512d result = _mm512_setzero_pd();
- for (size_t k8 = 0; k8 < n8; k8++) {
- for (size_t dk = 0; dk < 8; dk++) {
- __m512d b_val = _mm512_load_pd(&B_packed[(j8 * n8 + k8) * 64 + dk * 8]);
- __m512d a_bcast = _mm512_set1_pd(A[i * n + k8 * 8 + dk]);
- result = _mm512_fmadd_pd(a_bcast, b_val, result);
+ for (size_t ii = 0; ii < m; ii += ib) {
+ size_t i_end = (ii + ib < m) ? ii + ib : m;
+ for (size_t jj = 0; jj < p8 * 8; jj += jb) {
+ size_t j_end = (jj + jb < p8 * 8) ? jj + jb : p8 * 8;
+ size_t ti = i_end - ii;
+ size_t tj = j_end - jj;
+ double acc[64 * 64];
+ memset(acc, 0, ti * tj * sizeof(double));
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j + 8 <= j_end; j += 8) {
+ size_t lj = j - jj;
+ size_t j8idx = j / 8;
+ __m512d acc_vec = _mm512_setzero_pd();
+ for (size_t k = kk; k < k_end; k++) {
+ size_t k8 = k / 8;
+ size_t dk = k % 8;
+ __m512d a_bcast = _mm512_set1_pd(A[i * n + k]);
+ __m512d b_val = _mm512_load_pd(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
+ acc_vec = _mm512_fmadd_pd(a_bcast, b_val, acc_vec);
+ }
+ double tmp[8] __attribute__((aligned(64)));
+ _mm512_store_pd(tmp, acc_vec);
+ for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj];
+ }
+ for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
+ size_t lj = j - jj;
+ for (size_t k = kk; k < k_end; k++) {
+ acc[li * tj + lj] += A[i * n + k] * B[k * p + j];
+ }
+ }
}
}
- double tmp[8] __attribute__((aligned(64)));
- _mm512_store_pd(tmp, result);
- for (size_t dj = 0; dj < 8; dj++) {
- double v = tmp[dj];
- if (scale > 1.0) v /= scale;
- C[i * p + j8 * 8 + dj] = v;
+
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j < j_end; j++) {
+ size_t lj = j - jj;
+ double v = acc[li * tj + lj];
+ if (scale > 1.0) v /= scale;
+ C[i * p + j] = v;
+ }
}
}
+ }
+
+ for (size_t i = 0; i < m; i++) {
for (size_t j = p8 * 8; j < p; j++) {
double sum = 0.0;
for (size_t k = 0; k < n; k++) {