matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

commit 3498a66cb3f602d40f3e7091207cb7c0d6a9648b
parent 8cc0cb57ecf252548aa53e68c41deb151b119360
Author: finwo <finwo@pm.me>
Date:   Sun, 19 Apr 2026 16:26:32 +0200

f32/f64 initial implementation

Diffstat:
Msrc/matmul.c | 382+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Msrc/matmul.h | 14+++++++++++++-
Mtest/benchmark.c | 143+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------
Mtest/test_matmul.c | 526+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mtest/test_matmul_simd.h | 16++++++++++++++++
5 files changed, 1070 insertions(+), 11 deletions(-)

diff --git a/src/matmul.c b/src/matmul.c @@ -241,3 +241,385 @@ static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, cons } int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = _matmul_u8_i8_u8; + +/* ========================================================================== */ +/* f32_f32_f32 implementations */ +/* ========================================================================== */ + +int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { + const size_t ib = 64; + const size_t jb = 64; + const size_t kb = 16; + +#pragma omp parallel for schedule(static) + for (size_t ii = 0; ii < m; ii += ib) { + size_t i_end = (ii + ib < m) ? ii + ib : m; + for (size_t jj = 0; jj < p; jj += jb) { + size_t j_end = (jj + jb < p) ? jj + jb : p; + size_t ti = i_end - ii; + size_t tj = j_end - jj; + double acc[64 * 64]; + memset(acc, 0, ti * tj * sizeof(double)); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j < j_end; j++) { + size_t lj = j - jj; + double sum = 0.0; + for (size_t k = kk; k < k_end; k++) { + sum += (double)A[i * n + k] * (double)B[k * p + j]; + } + acc[li * tj + lj] += sum; + } + } + } + + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j < j_end; j++) { + size_t lj = j - jj; + double v = acc[li * tj + lj]; + if (scale > 1.0) v /= scale; + C[i * p + j] = (float)v; + } + } + } + } + return 0; +} + +#ifdef __AVX2__ +static void pack_b_f32(size_t n, size_t p, const float *B, float *B_packed) { + size_t n8 = n / 8; + size_t p8 = p / 8; + for (size_t j8 = 0; j8 < p8; j8++) { + for (size_t k8 = 0; k8 < n8; k8++) { + float *dst = &B_packed[(j8 * n8 + k8) * 64]; + for (size_t dk = 0; dk < 8; dk++) { + size_t k = k8 * 8 + dk; + for (size_t dj = 0; dj < 8; dj++) { + size_t j = j8 * 8 + dj; + dst[dk * 8 + dj] = B[k * p + j]; + } + } + } + } +} + +int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { + size_t n8 = n / 8; + size_t p8 = p / 8; + + float *B_packed; + if (posix_memalign((void **)&B_packed, 64, p8 * n8 * 64 * sizeof(float)) != 0) return -1; + pack_b_f32(n, p, B, B_packed); + +#pragma omp parallel for schedule(static) + for (size_t i = 0; i < m; i++) { + for (size_t j8 = 0; j8 < p8; j8++) { + __m256 result = _mm256_setzero_ps(); + for (size_t k8 = 0; k8 < n8; k8++) { + for (size_t dk = 0; dk < 8; dk++) { + __m256 b_val = _mm256_load_ps(&B_packed[(j8 * n8 + k8) * 64 + dk * 8]); + __m256 a_bcast = _mm256_set1_ps(A[i * n + k8 * 8 + dk]); + result = _mm256_fmadd_ps(a_bcast, b_val, result); + } + } + float tmp[8] __attribute__((aligned(32))); + _mm256_store_ps(tmp, result); + for (size_t dj = 0; dj < 8; dj++) { + float v = tmp[dj]; + if (scale > 1.0) v /= (float)scale; + C[i * p + j8 * 8 + dj] = v; + } + } + for (size_t j = p8 * 8; j < p; j++) { + double sum = 0.0; + for (size_t k = 0; k < n; k++) { + sum += (double)A[i * n + k] * (double)B[k * p + j]; + } + if (scale > 1.0) sum /= scale; + C[i * p + j] = (float)sum; + } + } + + free(B_packed); + return 0; +} +#endif + +#ifdef __AVX512F__ +static void pack_b_f32_512(size_t n, size_t p, const float *B, float *B_packed) { + size_t n16 = n / 16; + size_t p16 = p / 16; + for (size_t j16 = 0; j16 < p16; j16++) { + for (size_t k16 = 0; k16 < n16; k16++) { + float *dst = &B_packed[(j16 * n16 + k16) * 256]; + for (size_t dk = 0; dk < 16; dk++) { + size_t k = k16 * 16 + dk; + for (size_t dj = 0; dj < 16; dj++) { + size_t j = j16 * 16 + dj; + dst[dk * 16 + dj] = B[k * p + j]; + } + } + } + } +} + +int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { + size_t n16 = n / 16; + size_t p16 = p / 16; + + float *B_packed; + if (posix_memalign((void **)&B_packed, 64, p16 * n16 * 256 * sizeof(float)) != 0) return -1; + pack_b_f32_512(n, p, B, B_packed); + +#pragma omp parallel for schedule(static) + for (size_t i = 0; i < m; i++) { + for (size_t j16 = 0; j16 < p16; j16++) { + __m512 result = _mm512_setzero_ps(); + for (size_t k16 = 0; k16 < n16; k16++) { + for (size_t dk = 0; dk < 16; dk++) { + __m512 b_val = _mm512_load_ps(&B_packed[(j16 * n16 + k16) * 256 + dk * 16]); + __m512 a_bcast = _mm512_set1_ps(A[i * n + k16 * 16 + dk]); + result = _mm512_fmadd_ps(a_bcast, b_val, result); + } + } + float tmp[16] __attribute__((aligned(64))); + _mm512_store_ps(tmp, result); + for (size_t dj = 0; dj < 16; dj++) { + float v = tmp[dj]; + if (scale > 1.0) v /= (float)scale; + C[i * p + j16 * 16 + dj] = v; + } + } + for (size_t j = p16 * 16; j < p; j++) { + double sum = 0.0; + for (size_t k = 0; k < n; k++) { + sum += (double)A[i * n + k] * (double)B[k * p + j]; + } + if (scale > 1.0) sum /= scale; + C[i * p + j] = (float)sum; + } + } + + free(B_packed); + return 0; +} +#endif + +static int _matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512) + matmul_f32_f32_f32 = matmul_avx512_f32_f32_f32; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVX2) + matmul_f32_f32_f32 = matmul_avx2_f32_f32_f32; + else +#endif + matmul_f32_f32_f32 = matmul_scalar_f32_f32_f32; + initialized = 1; + } + return matmul_f32_f32_f32(m, n, p, A, B, C, scale); +} + +int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double) = _matmul_f32_f32_f32; + +/* ========================================================================== */ +/* f64_f64_f64 implementations */ +/* ========================================================================== */ + +int matmul_scalar_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { + const size_t ib = 64; + const size_t jb = 64; + const size_t kb = 8; + +#pragma omp parallel for schedule(static) + for (size_t ii = 0; ii < m; ii += ib) { + size_t i_end = (ii + ib < m) ? ii + ib : m; + for (size_t jj = 0; jj < p; jj += jb) { + size_t j_end = (jj + jb < p) ? jj + jb : p; + size_t ti = i_end - ii; + size_t tj = j_end - jj; + double acc[64 * 64]; + memset(acc, 0, ti * tj * sizeof(double)); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j < j_end; j++) { + size_t lj = j - jj; + double sum = 0.0; + for (size_t k = kk; k < k_end; k++) { + sum += A[i * n + k] * B[k * p + j]; + } + acc[li * tj + lj] += sum; + } + } + } + + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j < j_end; j++) { + size_t lj = j - jj; + double v = acc[li * tj + lj]; + if (scale > 1.0) v /= scale; + C[i * p + j] = v; + } + } + } + } + return 0; +} + +#ifdef __AVX2__ +static void pack_b_f64(size_t n, size_t p, const double *B, double *B_packed) { + size_t n4 = n / 4; + size_t p4 = p / 4; + for (size_t j4 = 0; j4 < p4; j4++) { + for (size_t k4 = 0; k4 < n4; k4++) { + double *dst = &B_packed[(j4 * n4 + k4) * 16]; + for (size_t dk = 0; dk < 4; dk++) { + size_t k = k4 * 4 + dk; + for (size_t dj = 0; dj < 4; dj++) { + size_t j = j4 * 4 + dj; + dst[dk * 4 + dj] = B[k * p + j]; + } + } + } + } +} + +int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { + size_t n4 = n / 4; + size_t p4 = p / 4; + + double *B_packed; + if (posix_memalign((void **)&B_packed, 64, p4 * n4 * 16 * sizeof(double)) != 0) return -1; + pack_b_f64(n, p, B, B_packed); + +#pragma omp parallel for schedule(static) + for (size_t i = 0; i < m; i++) { + for (size_t j4 = 0; j4 < p4; j4++) { + __m256d result = _mm256_setzero_pd(); + for (size_t k4 = 0; k4 < n4; k4++) { + for (size_t dk = 0; dk < 4; dk++) { + __m256d b_val = _mm256_load_pd(&B_packed[(j4 * n4 + k4) * 16 + dk * 4]); + __m256d a_bcast = _mm256_set1_pd(A[i * n + k4 * 4 + dk]); + result = _mm256_fmadd_pd(a_bcast, b_val, result); + } + } + double tmp[4] __attribute__((aligned(32))); + _mm256_store_pd(tmp, result); + for (size_t dj = 0; dj < 4; dj++) { + double v = tmp[dj]; + if (scale > 1.0) v /= scale; + C[i * p + j4 * 4 + dj] = v; + } + } + for (size_t j = p4 * 4; j < p; j++) { + double sum = 0.0; + for (size_t k = 0; k < n; k++) { + sum += A[i * n + k] * B[k * p + j]; + } + if (scale > 1.0) sum /= scale; + C[i * p + j] = sum; + } + } + + free(B_packed); + return 0; +} +#endif + +#ifdef __AVX512F__ +static void pack_b_f64_512(size_t n, size_t p, const double *B, double *B_packed) { + size_t n8 = n / 8; + size_t p8 = p / 8; + for (size_t j8 = 0; j8 < p8; j8++) { + for (size_t k8 = 0; k8 < n8; k8++) { + double *dst = &B_packed[(j8 * n8 + k8) * 64]; + for (size_t dk = 0; dk < 8; dk++) { + size_t k = k8 * 8 + dk; + for (size_t dj = 0; dj < 8; dj++) { + size_t j = j8 * 8 + dj; + dst[dk * 8 + dj] = B[k * p + j]; + } + } + } + } +} + +int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { + size_t n8 = n / 8; + size_t p8 = p / 8; + + double *B_packed; + if (posix_memalign((void **)&B_packed, 64, p8 * n8 * 64 * sizeof(double)) != 0) return -1; + pack_b_f64_512(n, p, B, B_packed); + +#pragma omp parallel for schedule(static) + for (size_t i = 0; i < m; i++) { + for (size_t j8 = 0; j8 < p8; j8++) { + __m512d result = _mm512_setzero_pd(); + for (size_t k8 = 0; k8 < n8; k8++) { + for (size_t dk = 0; dk < 8; dk++) { + __m512d b_val = _mm512_load_pd(&B_packed[(j8 * n8 + k8) * 64 + dk * 8]); + __m512d a_bcast = _mm512_set1_pd(A[i * n + k8 * 8 + dk]); + result = _mm512_fmadd_pd(a_bcast, b_val, result); + } + } + double tmp[8] __attribute__((aligned(64))); + _mm512_store_pd(tmp, result); + for (size_t dj = 0; dj < 8; dj++) { + double v = tmp[dj]; + if (scale > 1.0) v /= scale; + C[i * p + j8 * 8 + dj] = v; + } + } + for (size_t j = p8 * 8; j < p; j++) { + double sum = 0.0; + for (size_t k = 0; k < n; k++) { + sum += A[i * n + k] * B[k * p + j]; + } + if (scale > 1.0) sum /= scale; + C[i * p + j] = sum; + } + } + + free(B_packed); + return 0; +} +#endif + +static int _matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, + double scale) { + static int initialized = 0; + if (!initialized) { + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512) + matmul_f64_f64_f64 = matmul_avx512_f64_f64_f64; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVX2) + matmul_f64_f64_f64 = matmul_avx2_f64_f64_f64; + else +#endif + matmul_f64_f64_f64 = matmul_scalar_f64_f64_f64; + initialized = 1; + } + return matmul_f64_f64_f64(m, n, p, A, B, C, scale); +} + +int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, + double) = _matmul_f64_f64_f64; diff --git a/src/matmul.h b/src/matmul.h @@ -46,15 +46,27 @@ extern "C" { #endif extern int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double); +extern int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double); +extern int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, double); #define matmul(m, n, p, A, B, C, scale) \ _Generic((A), \ uint8_t *: _Generic((B), \ int8_t *: _Generic((C), \ uint8_t *: matmul_u8_i8_u8 \ + ) \ + ), \ + float *: _Generic((B), \ + float *: _Generic((C), \ + float *: matmul_f32_f32_f32 \ + ) \ + ), \ + double *: _Generic((B), \ + double *: _Generic((C), \ + double *: matmul_f64_f64_f64 \ ) \ ) \ - )((m), (n), (p), (A), (B), (C), (scale)) + )((m), (n), (p), (A), (B), (C), (scale)) #ifdef __cplusplus } diff --git a/test/benchmark.c b/test/benchmark.c @@ -34,7 +34,7 @@ static double percentile(double *sorted, double p, int n) { return sorted[lo] * (1.0 - frac) + sorted[hi] * frac; } -static void bench(size_t m, size_t n, size_t p, int runs) { +static void bench_u8_i8_u8(size_t m, size_t n, size_t p, int runs) { uint8_t *A = malloc(m * n); int8_t *B = malloc(n * p); uint8_t *C = malloc(m * p); @@ -75,25 +75,148 @@ static void bench(size_t m, size_t n, size_t p, int runs) { printf("%8zu x %8zu | %6.2f | %6.2f | %6.2f | %6.2f | %6.2f | %8.1f\n", m, n, percentile(times, 1, actual_runs), percentile(times, 5, actual_runs), percentile(times, 50, actual_runs), percentile(times, 95, actual_runs), percentile(times, 99, actual_runs), gflops); + + free(A); + free(B); + free(C); + free(Cwarm); +} + +static void bench_f32_f32_f32(size_t m, size_t n, size_t p, int runs) { + float *A = malloc(m * n * sizeof(float)); + float *B = malloc(n * p * sizeof(float)); + float *C = malloc(m * p * sizeof(float)); + float *Cwarm = malloc(m * p * sizeof(float)); + double times[RUNS]; + + if (!A || !B || !C || !Cwarm) { + fprintf(stderr, "OOM for %zu x %zu\n", m, n); + free(A); + free(B); + free(C); + free(Cwarm); + return; + } + + for (size_t i = 0; i < m * n; i++) A[i] = (float)(rand() % 256); + for (size_t i = 0; i < n * p; i++) B[i] = (float)(rand() % 256); + memset(C, 0, m * p * sizeof(float)); + memset(Cwarm, 0, m * p * sizeof(float)); + + matmul_f32_f32_f32(m, n, p, A, B, Cwarm, 0.0); + + int actual_runs = runs; + if (m >= 4096) actual_runs = 3; + + for (int r = 0; r < actual_runs; r++) { + memset(C, 0, m * p * sizeof(float)); + struct timespec start = timespec_now(); + matmul_f32_f32_f32(m, n, p, A, B, C, 0.0); + struct timespec end = timespec_now(); + times[r] = timespec_diff_ms(start, end); + } + + qsort(times, actual_runs, sizeof(double), compare_double); + + double gflops = 2.0 * m * n * p / (percentile(times, 50, actual_runs) * 1e6); + + printf("%8zu x %8zu | %6.2f | %6.2f | %6.2f | %6.2f | %6.2f | %8.1f\n", m, n, percentile(times, 1, actual_runs), + percentile(times, 5, actual_runs), percentile(times, 50, actual_runs), percentile(times, 95, actual_runs), + percentile(times, 99, actual_runs), gflops); + + free(A); + free(B); + free(C); + free(Cwarm); +} + +static void bench_f64_f64_f64(size_t m, size_t n, size_t p, int runs) { + double *A = malloc(m * n * sizeof(double)); + double *B = malloc(n * p * sizeof(double)); + double *C = malloc(m * p * sizeof(double)); + double *Cwarm = malloc(m * p * sizeof(double)); + double times[RUNS]; + + if (!A || !B || !C || !Cwarm) { + fprintf(stderr, "OOM for %zu x %zu\n", m, n); + free(A); + free(B); + free(C); + free(Cwarm); + return; + } + + for (size_t i = 0; i < m * n; i++) A[i] = (double)(rand() % 256); + for (size_t i = 0; i < n * p; i++) B[i] = (double)(rand() % 256); + memset(C, 0, m * p * sizeof(double)); + memset(Cwarm, 0, m * p * sizeof(double)); + + matmul_f64_f64_f64(m, n, p, A, B, Cwarm, 0.0); + + int actual_runs = runs; + if (m >= 4096) actual_runs = 3; + + for (int r = 0; r < actual_runs; r++) { + memset(C, 0, m * p * sizeof(double)); + struct timespec start = timespec_now(); + matmul_f64_f64_f64(m, n, p, A, B, C, 0.0); + struct timespec end = timespec_now(); + times[r] = timespec_diff_ms(start, end); + } + + qsort(times, actual_runs, sizeof(double), compare_double); + + double gflops = 2.0 * m * n * p / (percentile(times, 50, actual_runs) * 1e6); + + printf("%8zu x %8zu | %6.2f | %6.2f | %6.2f | %6.2f | %6.2f | %8.1f\n", m, n, percentile(times, 1, actual_runs), + percentile(times, 5, actual_runs), percentile(times, 50, actual_runs), percentile(times, 95, actual_runs), + percentile(times, 99, actual_runs), gflops); + + free(A); + free(B); + free(C); + free(Cwarm); } int main(void) { srand(42); + const size_t sizes[][3] = { + {16, 16, 16}, {64, 64, 64}, {256, 256, 256}, {1024, 1024, 1024}, {4096, 4096, 4096}, + }; + const int num_sizes = sizeof(sizes) / sizeof(sizes[0]); + printf("Benchmark: u8_i8_u8 matmul, %d runs per size\n", RUNS); - printf("--------------------------------------------------------------\n"); + printf("--------------------------------------------------------------------------------\n"); printf("%8s | %8s | %8s | %8s | %8s | %8s | %8s\n", "M x N", "1% (ms)", "5% (ms)", "50% (ms)", "95% (ms)", "99% (ms)", "GFLOPS"); - printf("--------------------------------------------------------------\n"); + printf("--------------------------------------------------------------------------------\n"); + for (int i = 0; i < num_sizes; i++) { + bench_u8_i8_u8(sizes[i][0], sizes[i][1], sizes[i][2], RUNS); + } + printf("--------------------------------------------------------------------------------\n"); + printf("\n"); - bench(16, 16, 16, RUNS); - bench(64, 64, 64, RUNS); - bench(256, 256, 256, RUNS); - bench(1024, 1024, 1024, RUNS); - bench(4096, 4096, 4096, RUNS); - // bench(16384, 16384, 16384, RUNS); + printf("Benchmark: f32_f32_f32 matmul, %d runs per size\n", RUNS); + printf("--------------------------------------------------------------------------------\n"); + printf("%8s | %8s | %8s | %8s | %8s | %8s | %8s\n", "M x N", "1% (ms)", "5% (ms)", "50% (ms)", "95% (ms)", "99% (ms)", + "GFLOPS"); + printf("--------------------------------------------------------------------------------\n"); + for (int i = 0; i < num_sizes; i++) { + bench_f32_f32_f32(sizes[i][0], sizes[i][1], sizes[i][2], RUNS); + } + printf("--------------------------------------------------------------------------------\n"); + printf("\n"); - printf("--------------------------------------------------------------\n"); + printf("Benchmark: f64_f64_f64 matmul, %d runs per size\n", RUNS); + printf("--------------------------------------------------------------------------------\n"); + printf("%8s | %8s | %8s | %8s | %8s | %8s | %8s\n", "M x N", "1% (ms)", "5% (ms)", "50% (ms)", "95% (ms)", "99% (ms)", + "GFLOPS"); + printf("--------------------------------------------------------------------------------\n"); + for (int i = 0; i < num_sizes; i++) { + bench_f64_f64_f64(sizes[i][0], sizes[i][1], sizes[i][2], RUNS); + } + printf("--------------------------------------------------------------------------------\n"); return 0; } diff --git a/test/test_matmul.c b/test/test_matmul.c @@ -230,6 +230,486 @@ static MunitResult test_dispatched_u8_i8_u8_scaled_medium(const MunitParameter * return test_u8_i8_u8_scaled_medium("dispatched", matmul_u8_i8_u8, 0); } +/* ========================================================================== */ +/* f32_f32_f32 tests */ +/* ========================================================================== */ + +static void ref_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { + for (size_t i = 0; i < m; i++) + for (size_t j = 0; j < p; j++) { + double sum = 0.0; + for (size_t k = 0; k < n; k++) sum += (double)A[i * n + k] * (double)B[k * p + j]; + if (scale > 1.0) sum /= scale; + C[i * p + j] = (float)sum; + } +} + +static MunitResult test_f32_f32_f32_small(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *, + float *, double), + double epsilon) { + float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}; + float C[4], E[4]; + + ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0); + matmul_fn(2, 3, 2, A, B, C, 0.0); + + for (int i = 0; i < 4; i++) { + double d = (double)E[i] - (double)C[i]; + if (d < 0) d = -d; + if (d > epsilon) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_f32_f32_f32_medium(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *, + float *, double), + double epsilon) { + const size_t m = 64, n = 64, p = 64; + float *A = malloc(m * n * sizeof(float)); + float *B = malloc(n * p * sizeof(float)); + float *C = malloc(m * p * sizeof(float)); + float *E = malloc(m * p * sizeof(float)); + if (!A || !B || !C || !E) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_SKIP; + } + + for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251); + for (size_t i = 0; i < n * p; i++) B[i] = (float)(((i * 11 + 17) % 211) - 105); + memset(C, 0, m * p * sizeof(float)); + memset(E, 0, m * p * sizeof(float)); + + ref_f32_f32_f32(m, n, p, A, B, E, 0.0); + matmul_fn(m, n, p, A, B, C, 0.0); + + for (size_t i = 0; i < m * p; i++) { + double d = (double)E[i] - (double)C[i]; + if (d < 0) d = -d; + if (d > epsilon) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_FAIL; + } + } + + free(A); + free(B); + free(C); + free(E); + return MUNIT_OK; +} + +static MunitResult test_f32_f32_f32_scaled_small(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *, + float *, double), + double epsilon) { + float A[] = {8.0f, 16.0f, 24.0f, 32.0f, 40.0f, 48.0f}; + float B[] = {2.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f}; + float C[4], E[4]; + + ref_f32_f32_f32(2, 3, 2, A, B, E, 4.0); + matmul_fn(2, 3, 2, A, B, C, 4.0); + + for (int i = 0; i < 4; i++) { + double d = (double)E[i] - (double)C[i]; + if (d < 0) d = -d; + if (d > epsilon) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_f32_f32_f32_scaled_medium(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *, + float *, double), + double epsilon) { + const size_t m = 64, n = 64, p = 64; + float *A = malloc(m * n * sizeof(float)); + float *B = malloc(n * p * sizeof(float)); + float *C = malloc(m * p * sizeof(float)); + float *E = malloc(m * p * sizeof(float)); + if (!A || !B || !C || !E) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_SKIP; + } + + for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251); + for (size_t i = 0; i < n * p; i++) B[i] = (float)(((i * 11 + 17) % 211) - 105); + memset(C, 0, m * p * sizeof(float)); + memset(E, 0, m * p * sizeof(float)); + + ref_f32_f32_f32(m, n, p, A, B, E, 8.0); + matmul_fn(m, n, p, A, B, C, 8.0); + + for (size_t i = 0; i < m * p; i++) { + double d = (double)E[i] - (double)C[i]; + if (d < 0) d = -d; + if (d > epsilon) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_FAIL; + } + } + + free(A); + free(B); + free(C); + free(E); + return MUNIT_OK; +} + +static MunitResult test_scalar_f32_f32_f32(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_small("scalar", matmul_scalar_f32_f32_f32, 1e-5); +} + +static MunitResult test_scalar_f32_f32_f32_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_medium("scalar", matmul_scalar_f32_f32_f32, 1e-3); +} + +static MunitResult test_scalar_f32_f32_f32_scaled(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_scaled_small("scalar", matmul_scalar_f32_f32_f32, 1e-5); +} + +static MunitResult test_scalar_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_scaled_medium("scalar", matmul_scalar_f32_f32_f32, 1e-3); +} + +#ifdef __AVX2__ +static MunitResult test_avx2_f32_f32_f32(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_small("avx2", matmul_avx2_f32_f32_f32, 1e-5); +} + +static MunitResult test_avx2_f32_f32_f32_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_medium("avx2", matmul_avx2_f32_f32_f32, 1e-3); +} + +static MunitResult test_avx2_f32_f32_f32_scaled(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_scaled_small("avx2", matmul_avx2_f32_f32_f32, 1e-5); +} + +static MunitResult test_avx2_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_scaled_medium("avx2", matmul_avx2_f32_f32_f32, 1e-3); +} +#endif + +#ifdef __AVX512F__ +static MunitResult test_avx512_f32_f32_f32(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_small("avx512", matmul_avx512_f32_f32_f32, 1e-5); +} + +static MunitResult test_avx512_f32_f32_f32_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_medium("avx512", matmul_avx512_f32_f32_f32, 1e-3); +} + +static MunitResult test_avx512_f32_f32_f32_scaled(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_scaled_small("avx512", matmul_avx512_f32_f32_f32, 1e-5); +} + +static MunitResult test_avx512_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_scaled_medium("avx512", matmul_avx512_f32_f32_f32, 1e-3); +} +#endif + +static MunitResult test_dispatched_f32_f32_f32(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_small("dispatched", matmul_f32_f32_f32, 1e-5); +} + +static MunitResult test_dispatched_f32_f32_f32_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_medium("dispatched", matmul_f32_f32_f32, 1e-3); +} + +static MunitResult test_dispatched_f32_f32_f32_scaled(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_scaled_small("dispatched", matmul_f32_f32_f32, 1e-5); +} + +static MunitResult test_dispatched_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f32_f32_f32_scaled_medium("dispatched", matmul_f32_f32_f32, 1e-3); +} + +/* ========================================================================== */ +/* f64_f64_f64 tests */ +/* ========================================================================== */ + +static void ref_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { + for (size_t i = 0; i < m; i++) + for (size_t j = 0; j < p; j++) { + double sum = 0.0; + for (size_t k = 0; k < n; k++) sum += A[i * n + k] * B[k * p + j]; + if (scale > 1.0) sum /= scale; + C[i * p + j] = sum; + } +} + +static MunitResult test_f64_f64_f64_small(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const double *, const double *, + double *, double), + double epsilon) { + double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0}; + double C[4], E[4]; + + ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0); + matmul_fn(2, 3, 2, A, B, C, 0.0); + + for (int i = 0; i < 4; i++) { + double d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > epsilon) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_f64_f64_f64_medium(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const double *, const double *, + double *, double), + double epsilon) { + const size_t m = 64, n = 64, p = 64; + double *A = malloc(m * n * sizeof(double)); + double *B = malloc(n * p * sizeof(double)); + double *C = malloc(m * p * sizeof(double)); + double *E = malloc(m * p * sizeof(double)); + if (!A || !B || !C || !E) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_SKIP; + } + + for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251); + for (size_t i = 0; i < n * p; i++) B[i] = (double)(((i * 11 + 17) % 211) - 105); + memset(C, 0, m * p * sizeof(double)); + memset(E, 0, m * p * sizeof(double)); + + ref_f64_f64_f64(m, n, p, A, B, E, 0.0); + matmul_fn(m, n, p, A, B, C, 0.0); + + for (size_t i = 0; i < m * p; i++) { + double d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > epsilon) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_FAIL; + } + } + + free(A); + free(B); + free(C); + free(E); + return MUNIT_OK; +} + +static MunitResult test_f64_f64_f64_scaled_small(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const double *, + const double *, double *, double), + double epsilon) { + double A[] = {8.0, 16.0, 24.0, 32.0, 40.0, 48.0}; + double B[] = {2.0, 0.0, 0.0, 2.0, 0.0, 0.0}; + double C[4], E[4]; + + ref_f64_f64_f64(2, 3, 2, A, B, E, 4.0); + matmul_fn(2, 3, 2, A, B, C, 4.0); + + for (int i = 0; i < 4; i++) { + double d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > epsilon) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_f64_f64_f64_scaled_medium(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const double *, + const double *, double *, double), + double epsilon) { + const size_t m = 64, n = 64, p = 64; + double *A = malloc(m * n * sizeof(double)); + double *B = malloc(n * p * sizeof(double)); + double *C = malloc(m * p * sizeof(double)); + double *E = malloc(m * p * sizeof(double)); + if (!A || !B || !C || !E) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_SKIP; + } + + for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251); + for (size_t i = 0; i < n * p; i++) B[i] = (double)(((i * 11 + 17) % 211) - 105); + memset(C, 0, m * p * sizeof(double)); + memset(E, 0, m * p * sizeof(double)); + + ref_f64_f64_f64(m, n, p, A, B, E, 8.0); + matmul_fn(m, n, p, A, B, C, 8.0); + + for (size_t i = 0; i < m * p; i++) { + double d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > epsilon) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_FAIL; + } + } + + free(A); + free(B); + free(C); + free(E); + return MUNIT_OK; +} + +static MunitResult test_scalar_f64_f64_f64(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_small("scalar", matmul_scalar_f64_f64_f64, 1e-12); +} + +static MunitResult test_scalar_f64_f64_f64_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_medium("scalar", matmul_scalar_f64_f64_f64, 1e-9); +} + +static MunitResult test_scalar_f64_f64_f64_scaled(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_scaled_small("scalar", matmul_scalar_f64_f64_f64, 1e-12); +} + +static MunitResult test_scalar_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_scaled_medium("scalar", matmul_scalar_f64_f64_f64, 1e-9); +} + +#ifdef __AVX2__ +static MunitResult test_avx2_f64_f64_f64(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_small("avx2", matmul_avx2_f64_f64_f64, 1e-12); +} + +static MunitResult test_avx2_f64_f64_f64_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_medium("avx2", matmul_avx2_f64_f64_f64, 1e-9); +} + +static MunitResult test_avx2_f64_f64_f64_scaled(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_scaled_small("avx2", matmul_avx2_f64_f64_f64, 1e-12); +} + +static MunitResult test_avx2_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_scaled_medium("avx2", matmul_avx2_f64_f64_f64, 1e-9); +} +#endif + +#ifdef __AVX512F__ +static MunitResult test_avx512_f64_f64_f64(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_small("avx512", matmul_avx512_f64_f64_f64, 1e-12); +} + +static MunitResult test_avx512_f64_f64_f64_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_medium("avx512", matmul_avx512_f64_f64_f64, 1e-9); +} + +static MunitResult test_avx512_f64_f64_f64_scaled(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_scaled_small("avx512", matmul_avx512_f64_f64_f64, 1e-12); +} + +static MunitResult test_avx512_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_scaled_medium("avx512", matmul_avx512_f64_f64_f64, 1e-9); +} +#endif + +static MunitResult test_dispatched_f64_f64_f64(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_small("dispatched", matmul_f64_f64_f64, 1e-12); +} + +static MunitResult test_dispatched_f64_f64_f64_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_medium("dispatched", matmul_f64_f64_f64, 1e-9); +} + +static MunitResult test_dispatched_f64_f64_f64_scaled(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_scaled_small("dispatched", matmul_f64_f64_f64, 1e-12); +} + +static MunitResult test_dispatched_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_f64_f64_f64_scaled_medium("dispatched", matmul_f64_f64_f64, 1e-9); +} + static MunitTest tests[] = { {"/scalar-u8-i8-u8", test_scalar_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/scalar-u8-i8-u8-medium", test_scalar_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, @@ -247,6 +727,52 @@ static MunitTest tests[] = { {"/dispatched-u8-i8-u8-scaled", test_dispatched_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/dispatched-u8-i8-u8-scaled-medium", test_dispatched_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/scalar-f32-f32-f32", test_scalar_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/scalar-f32-f32-f32-medium", test_scalar_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/scalar-f32-f32-f32-scaled", test_scalar_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/scalar-f32-f32-f32-scaled-medium", test_scalar_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, + NULL}, +#ifdef __AVX2__ + {"/avx2-f32-f32-f32", test_avx2_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx2-f32-f32-f32-medium", test_avx2_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx2-f32-f32-f32-scaled", test_avx2_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx2-f32-f32-f32-scaled-medium", test_avx2_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, +#endif +#ifdef __AVX512F__ + {"/avx512-f32-f32-f32", test_avx512_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx512-f32-f32-f32-medium", test_avx512_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx512-f32-f32-f32-scaled", test_avx512_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx512-f32-f32-f32-scaled-medium", test_avx512_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, + NULL}, +#endif + {"/dispatched-f32-f32-f32", test_dispatched_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/dispatched-f32-f32-f32-medium", test_dispatched_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/dispatched-f32-f32-f32-scaled", test_dispatched_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/dispatched-f32-f32-f32-scaled-medium", test_dispatched_f32_f32_f32_scaled_medium, NULL, NULL, + MUNIT_TEST_OPTION_NONE, NULL}, + {"/scalar-f64-f64-f64", test_scalar_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/scalar-f64-f64-f64-medium", test_scalar_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/scalar-f64-f64-f64-scaled", test_scalar_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/scalar-f64-f64-f64-scaled-medium", test_scalar_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, + NULL}, +#ifdef __AVX2__ + {"/avx2-f64-f64-f64", test_avx2_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx2-f64-f64-f64-medium", test_avx2_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx2-f64-f64-f64-scaled", test_avx2_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx2-f64-f64-f64-scaled-medium", test_avx2_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, +#endif +#ifdef __AVX512F__ + {"/avx512-f64-f64-f64", test_avx512_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx512-f64-f64-f64-medium", test_avx512_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx512-f64-f64-f64-scaled", test_avx512_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx512-f64-f64-f64-scaled-medium", test_avx512_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, + NULL}, +#endif + {"/dispatched-f64-f64-f64", test_dispatched_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/dispatched-f64-f64-f64-medium", test_dispatched_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/dispatched-f64-f64-f64-scaled", test_dispatched_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/dispatched-f64-f64-f64-scaled-medium", test_dispatched_f64_f64_f64_scaled_medium, NULL, NULL, + MUNIT_TEST_OPTION_NONE, NULL}, {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}}; static const MunitSuite suite = {"/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE}; diff --git a/test/test_matmul_simd.h b/test/test_matmul_simd.h @@ -70,6 +70,22 @@ int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, c double scale); #endif +int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale); +#ifdef __AVX2__ +int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale); +#endif +#ifdef __AVX512F__ +int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale); +#endif + +int matmul_scalar_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale); +#ifdef __AVX2__ +int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale); +#endif +#ifdef __AVX512F__ +int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale); +#endif + #ifdef __cplusplus } #endif