matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

commit afc408a3afb47a47d58034b8f0ce806b5ad47c0a
parent 3498a66cb3f602d40f3e7091207cb7c0d6a9648b
Author: finwo <finwo@pm.me>
Date:   Mon, 20 Apr 2026 02:07:11 +0200

float performance increase

Diffstat:
M src/matmul.c | 256 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------
1 file changed, 200 insertions(+), 56 deletions(-)

diff --git a/src/matmul.c b/src/matmul.c @@ -309,6 +309,10 @@ static void pack_b_f32(size_t n, size_t p, const float *B, float *B_packed) { } int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { + const size_t ib = 64; + const size_t jb = 64; + const size_t kb = 16; + size_t n8 = n / 8; size_t p8 = p / 8; @@ -317,24 +321,56 @@ int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const pack_b_f32(n, p, B, B_packed); #pragma omp parallel for schedule(static) - for (size_t i = 0; i < m; i++) { - for (size_t j8 = 0; j8 < p8; j8++) { - __m256 result = _mm256_setzero_ps(); - for (size_t k8 = 0; k8 < n8; k8++) { - for (size_t dk = 0; dk < 8; dk++) { - __m256 b_val = _mm256_load_ps(&B_packed[(j8 * n8 + k8) * 64 + dk * 8]); - __m256 a_bcast = _mm256_set1_ps(A[i * n + k8 * 8 + dk]); - result = _mm256_fmadd_ps(a_bcast, b_val, result); + for (size_t ii = 0; ii < m; ii += ib) { + size_t i_end = (ii + ib < m) ? ii + ib : m; + for (size_t jj = 0; jj < p8 * 8; jj += jb) { + size_t j_end = (jj + jb < p8 * 8) ? jj + jb : p8 * 8; + size_t ti = i_end - ii; + size_t tj = j_end - jj; + float acc[64 * 64]; + memset(acc, 0, ti * tj * sizeof(float)); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? 
kk + kb : n; + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j + 8 <= j_end; j += 8) { + size_t lj = j - jj; + size_t j8idx = j / 8; + __m256 acc_vec = _mm256_setzero_ps(); + for (size_t k = kk; k < k_end; k++) { + size_t k8 = k / 8; + size_t dk = k % 8; + __m256 a_bcast = _mm256_set1_ps(A[i * n + k]); + __m256 b_val = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]); + acc_vec = _mm256_fmadd_ps(a_bcast, b_val, acc_vec); + } + float tmp[8] __attribute__((aligned(32))); + _mm256_store_ps(tmp, acc_vec); + for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj]; + } + for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { + size_t lj = j - jj; + for (size_t k = kk; k < k_end; k++) { + acc[li * tj + lj] += (double)A[i * n + k] * (double)B[k * p + j]; + } + } } } - float tmp[8] __attribute__((aligned(32))); - _mm256_store_ps(tmp, result); - for (size_t dj = 0; dj < 8; dj++) { - float v = tmp[dj]; - if (scale > 1.0) v /= (float)scale; - C[i * p + j8 * 8 + dj] = v; + + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j < j_end; j++) { + size_t lj = j - jj; + float v = acc[li * tj + lj]; + if (scale > 1.0) v /= (float)scale; + C[i * p + j] = v; + } } } + } + + for (size_t i = 0; i < m; i++) { for (size_t j = p8 * 8; j < p; j++) { double sum = 0.0; for (size_t k = 0; k < n; k++) { @@ -369,6 +405,10 @@ static void pack_b_f32_512(size_t n, size_t p, const float *B, float *B_packed) } int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { + const size_t ib = 64; + const size_t jb = 64; + const size_t kb = 16; + size_t n16 = n / 16; size_t p16 = p / 16; @@ -377,24 +417,56 @@ int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, cons pack_b_f32_512(n, p, B, B_packed); #pragma omp parallel for schedule(static) - for (size_t i = 0; i < m; i++) { - for (size_t j16 = 0; j16 < p16; j16++) { - __m512 result = 
_mm512_setzero_ps(); - for (size_t k16 = 0; k16 < n16; k16++) { - for (size_t dk = 0; dk < 16; dk++) { - __m512 b_val = _mm512_load_ps(&B_packed[(j16 * n16 + k16) * 256 + dk * 16]); - __m512 a_bcast = _mm512_set1_ps(A[i * n + k16 * 16 + dk]); - result = _mm512_fmadd_ps(a_bcast, b_val, result); + for (size_t ii = 0; ii < m; ii += ib) { + size_t i_end = (ii + ib < m) ? ii + ib : m; + for (size_t jj = 0; jj < p16 * 16; jj += jb) { + size_t j_end = (jj + jb < p16 * 16) ? jj + jb : p16 * 16; + size_t ti = i_end - ii; + size_t tj = j_end - jj; + float acc[64 * 64]; + memset(acc, 0, ti * tj * sizeof(float)); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j + 16 <= j_end; j += 16) { + size_t lj = j - jj; + size_t j16idx = j / 16; + __m512 acc_vec = _mm512_setzero_ps(); + for (size_t k = kk; k < k_end; k++) { + size_t k16 = k / 16; + size_t dk = k % 16; + __m512 a_bcast = _mm512_set1_ps(A[i * n + k]); + __m512 b_val = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]); + acc_vec = _mm512_fmadd_ps(a_bcast, b_val, acc_vec); + } + float tmp[16] __attribute__((aligned(64))); + _mm512_store_ps(tmp, acc_vec); + for (size_t dj = 0; dj < 16; dj++) acc[li * tj + lj + dj] += tmp[dj]; + } + for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) { + size_t lj = j - jj; + for (size_t k = kk; k < k_end; k++) { + acc[li * tj + lj] += (double)A[i * n + k] * (double)B[k * p + j]; + } + } } } - float tmp[16] __attribute__((aligned(64))); - _mm512_store_ps(tmp, result); - for (size_t dj = 0; dj < 16; dj++) { - float v = tmp[dj]; - if (scale > 1.0) v /= (float)scale; - C[i * p + j16 * 16 + dj] = v; + + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j < j_end; j++) { + size_t lj = j - jj; + float v = acc[li * tj + lj]; + if (scale > 1.0) v /= (float)scale; + C[i * p + j] = v; + } } } + } + + for (size_t i = 0; i < m; i++) { 
for (size_t j = p16 * 16; j < p; j++) { double sum = 0.0; for (size_t k = 0; k < n; k++) { @@ -499,6 +571,10 @@ static void pack_b_f64(size_t n, size_t p, const double *B, double *B_packed) { } int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { + const size_t ib = 64; + const size_t jb = 64; + const size_t kb = 8; + size_t n4 = n / 4; size_t p4 = p / 4; @@ -507,24 +583,56 @@ int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const pack_b_f64(n, p, B, B_packed); #pragma omp parallel for schedule(static) - for (size_t i = 0; i < m; i++) { - for (size_t j4 = 0; j4 < p4; j4++) { - __m256d result = _mm256_setzero_pd(); - for (size_t k4 = 0; k4 < n4; k4++) { - for (size_t dk = 0; dk < 4; dk++) { - __m256d b_val = _mm256_load_pd(&B_packed[(j4 * n4 + k4) * 16 + dk * 4]); - __m256d a_bcast = _mm256_set1_pd(A[i * n + k4 * 4 + dk]); - result = _mm256_fmadd_pd(a_bcast, b_val, result); + for (size_t ii = 0; ii < m; ii += ib) { + size_t i_end = (ii + ib < m) ? ii + ib : m; + for (size_t jj = 0; jj < p4 * 4; jj += jb) { + size_t j_end = (jj + jb < p4 * 4) ? jj + jb : p4 * 4; + size_t ti = i_end - ii; + size_t tj = j_end - jj; + double acc[64 * 64]; + memset(acc, 0, ti * tj * sizeof(double)); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? 
kk + kb : n; + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j + 4 <= j_end; j += 4) { + size_t lj = j - jj; + size_t j4idx = j / 4; + __m256d acc_vec = _mm256_setzero_pd(); + for (size_t k = kk; k < k_end; k++) { + size_t k4 = k / 4; + size_t dk = k % 4; + __m256d a_bcast = _mm256_set1_pd(A[i * n + k]); + __m256d b_val = _mm256_load_pd(&B_packed[(j4idx * n4 + k4) * 16 + dk * 4]); + acc_vec = _mm256_fmadd_pd(a_bcast, b_val, acc_vec); + } + double tmp[4] __attribute__((aligned(32))); + _mm256_store_pd(tmp, acc_vec); + for (size_t dj = 0; dj < 4; dj++) acc[li * tj + lj + dj] += tmp[dj]; + } + for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) { + size_t lj = j - jj; + for (size_t k = kk; k < k_end; k++) { + acc[li * tj + lj] += A[i * n + k] * B[k * p + j]; + } + } } } - double tmp[4] __attribute__((aligned(32))); - _mm256_store_pd(tmp, result); - for (size_t dj = 0; dj < 4; dj++) { - double v = tmp[dj]; - if (scale > 1.0) v /= scale; - C[i * p + j4 * 4 + dj] = v; + + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j < j_end; j++) { + size_t lj = j - jj; + double v = acc[li * tj + lj]; + if (scale > 1.0) v /= scale; + C[i * p + j] = v; + } } } + } + + for (size_t i = 0; i < m; i++) { for (size_t j = p4 * 4; j < p; j++) { double sum = 0.0; for (size_t k = 0; k < n; k++) { @@ -559,6 +667,10 @@ static void pack_b_f64_512(size_t n, size_t p, const double *B, double *B_packed } int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { + const size_t ib = 64; + const size_t jb = 64; + const size_t kb = 8; + size_t n8 = n / 8; size_t p8 = p / 8; @@ -567,24 +679,56 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con pack_b_f64_512(n, p, B, B_packed); #pragma omp parallel for schedule(static) - for (size_t i = 0; i < m; i++) { - for (size_t j8 = 0; j8 < p8; j8++) { - __m512d result = _mm512_setzero_pd(); - for 
(size_t k8 = 0; k8 < n8; k8++) { - for (size_t dk = 0; dk < 8; dk++) { - __m512d b_val = _mm512_load_pd(&B_packed[(j8 * n8 + k8) * 64 + dk * 8]); - __m512d a_bcast = _mm512_set1_pd(A[i * n + k8 * 8 + dk]); - result = _mm512_fmadd_pd(a_bcast, b_val, result); + for (size_t ii = 0; ii < m; ii += ib) { + size_t i_end = (ii + ib < m) ? ii + ib : m; + for (size_t jj = 0; jj < p8 * 8; jj += jb) { + size_t j_end = (jj + jb < p8 * 8) ? jj + jb : p8 * 8; + size_t ti = i_end - ii; + size_t tj = j_end - jj; + double acc[64 * 64]; + memset(acc, 0, ti * tj * sizeof(double)); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j + 8 <= j_end; j += 8) { + size_t lj = j - jj; + size_t j8idx = j / 8; + __m512d acc_vec = _mm512_setzero_pd(); + for (size_t k = kk; k < k_end; k++) { + size_t k8 = k / 8; + size_t dk = k % 8; + __m512d a_bcast = _mm512_set1_pd(A[i * n + k]); + __m512d b_val = _mm512_load_pd(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]); + acc_vec = _mm512_fmadd_pd(a_bcast, b_val, acc_vec); + } + double tmp[8] __attribute__((aligned(64))); + _mm512_store_pd(tmp, acc_vec); + for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj]; + } + for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { + size_t lj = j - jj; + for (size_t k = kk; k < k_end; k++) { + acc[li * tj + lj] += A[i * n + k] * B[k * p + j]; + } + } } } - double tmp[8] __attribute__((aligned(64))); - _mm512_store_pd(tmp, result); - for (size_t dj = 0; dj < 8; dj++) { - double v = tmp[dj]; - if (scale > 1.0) v /= scale; - C[i * p + j8 * 8 + dj] = v; + + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j < j_end; j++) { + size_t lj = j - jj; + double v = acc[li * tj + lj]; + if (scale > 1.0) v /= scale; + C[i * p + j] = v; + } } } + } + + for (size_t i = 0; i < m; i++) { for (size_t j = p8 * 8; j < p; j++) { double sum = 0.0; for (size_t k = 0; k < n; 
k++) {