commit 3498a66cb3f602d40f3e7091207cb7c0d6a9648b
parent 8cc0cb57ecf252548aa53e68c41deb151b119360
Author: finwo <finwo@pm.me>
Date: Sun, 19 Apr 2026 16:26:32 +0200
f32/f64 initial implementation
Diffstat:
5 files changed, 1070 insertions(+), 11 deletions(-)
diff --git a/src/matmul.c b/src/matmul.c
@@ -241,3 +241,385 @@ static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, cons
}
int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = _matmul_u8_i8_u8;
+
+/* ========================================================================== */
+/* f32_f32_f32 implementations */
+/* ========================================================================== */
+
+int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
+ const size_t ib = 64;
+ const size_t jb = 64;
+ const size_t kb = 16;
+
+#pragma omp parallel for schedule(static)
+ for (size_t ii = 0; ii < m; ii += ib) {
+ size_t i_end = (ii + ib < m) ? ii + ib : m;
+ for (size_t jj = 0; jj < p; jj += jb) {
+ size_t j_end = (jj + jb < p) ? jj + jb : p;
+ size_t ti = i_end - ii;
+ size_t tj = j_end - jj;
+ double acc[64 * 64];
+ memset(acc, 0, ti * tj * sizeof(double));
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j < j_end; j++) {
+ size_t lj = j - jj;
+ double sum = 0.0;
+ for (size_t k = kk; k < k_end; k++) {
+ sum += (double)A[i * n + k] * (double)B[k * p + j];
+ }
+ acc[li * tj + lj] += sum;
+ }
+ }
+ }
+
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j < j_end; j++) {
+ size_t lj = j - jj;
+ double v = acc[li * tj + lj];
+ if (scale > 1.0) v /= scale;
+ C[i * p + j] = (float)v;
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+#ifdef __AVX2__
+static void pack_b_f32(size_t n, size_t p, const float *B, float *B_packed) {
+ size_t n8 = n / 8;
+ size_t p8 = p / 8;
+ for (size_t j8 = 0; j8 < p8; j8++) {
+ for (size_t k8 = 0; k8 < n8; k8++) {
+ float *dst = &B_packed[(j8 * n8 + k8) * 64];
+ for (size_t dk = 0; dk < 8; dk++) {
+ size_t k = k8 * 8 + dk;
+ for (size_t dj = 0; dj < 8; dj++) {
+ size_t j = j8 * 8 + dj;
+ dst[dk * 8 + dj] = B[k * p + j];
+ }
+ }
+ }
+ }
+}
+
+int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
+ size_t n8 = n / 8;
+ size_t p8 = p / 8;
+
+ float *B_packed;
+ if (posix_memalign((void **)&B_packed, 64, p8 * n8 * 64 * sizeof(float)) != 0) return -1;
+ pack_b_f32(n, p, B, B_packed);
+
+#pragma omp parallel for schedule(static)
+ for (size_t i = 0; i < m; i++) {
+ for (size_t j8 = 0; j8 < p8; j8++) {
+ __m256 result = _mm256_setzero_ps();
+ for (size_t k8 = 0; k8 < n8; k8++) {
+ for (size_t dk = 0; dk < 8; dk++) {
+ __m256 b_val = _mm256_load_ps(&B_packed[(j8 * n8 + k8) * 64 + dk * 8]);
+ __m256 a_bcast = _mm256_set1_ps(A[i * n + k8 * 8 + dk]);
+ result = _mm256_fmadd_ps(a_bcast, b_val, result);
+ }
+ }
+ float tmp[8] __attribute__((aligned(32)));
+ _mm256_store_ps(tmp, result);
+ for (size_t dj = 0; dj < 8; dj++) {
+ float v = tmp[dj];
+ if (scale > 1.0) v /= (float)scale;
+ C[i * p + j8 * 8 + dj] = v;
+ }
+ }
+ for (size_t j = p8 * 8; j < p; j++) {
+ double sum = 0.0;
+ for (size_t k = 0; k < n; k++) {
+ sum += (double)A[i * n + k] * (double)B[k * p + j];
+ }
+ if (scale > 1.0) sum /= scale;
+ C[i * p + j] = (float)sum;
+ }
+ }
+
+ free(B_packed);
+ return 0;
+}
+#endif
+
+#ifdef __AVX512F__
+static void pack_b_f32_512(size_t n, size_t p, const float *B, float *B_packed) {
+ size_t n16 = n / 16;
+ size_t p16 = p / 16;
+ for (size_t j16 = 0; j16 < p16; j16++) {
+ for (size_t k16 = 0; k16 < n16; k16++) {
+ float *dst = &B_packed[(j16 * n16 + k16) * 256];
+ for (size_t dk = 0; dk < 16; dk++) {
+ size_t k = k16 * 16 + dk;
+ for (size_t dj = 0; dj < 16; dj++) {
+ size_t j = j16 * 16 + dj;
+ dst[dk * 16 + dj] = B[k * p + j];
+ }
+ }
+ }
+ }
+}
+
+int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
+ size_t n16 = n / 16;
+ size_t p16 = p / 16;
+
+ float *B_packed;
+ if (posix_memalign((void **)&B_packed, 64, p16 * n16 * 256 * sizeof(float)) != 0) return -1;
+ pack_b_f32_512(n, p, B, B_packed);
+
+#pragma omp parallel for schedule(static)
+ for (size_t i = 0; i < m; i++) {
+ for (size_t j16 = 0; j16 < p16; j16++) {
+ __m512 result = _mm512_setzero_ps();
+ for (size_t k16 = 0; k16 < n16; k16++) {
+ for (size_t dk = 0; dk < 16; dk++) {
+ __m512 b_val = _mm512_load_ps(&B_packed[(j16 * n16 + k16) * 256 + dk * 16]);
+ __m512 a_bcast = _mm512_set1_ps(A[i * n + k16 * 16 + dk]);
+ result = _mm512_fmadd_ps(a_bcast, b_val, result);
+ }
+ }
+ float tmp[16] __attribute__((aligned(64)));
+ _mm512_store_ps(tmp, result);
+ for (size_t dj = 0; dj < 16; dj++) {
+ float v = tmp[dj];
+ if (scale > 1.0) v /= (float)scale;
+ C[i * p + j16 * 16 + dj] = v;
+ }
+ }
+ for (size_t j = p16 * 16; j < p; j++) {
+ double sum = 0.0;
+ for (size_t k = 0; k < n; k++) {
+ sum += (double)A[i * n + k] * (double)B[k * p + j];
+ }
+ if (scale > 1.0) sum /= scale;
+ C[i * p + j] = (float)sum;
+ }
+ }
+
+ free(B_packed);
+ return 0;
+}
+#endif
+
+static int _matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
+ static int initialized = 0;
+ if (!initialized) {
+ matmul_feature_t feat = matmul_get_feature();
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512)
+ matmul_f32_f32_f32 = matmul_avx512_f32_f32_f32;
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVX2)
+ matmul_f32_f32_f32 = matmul_avx2_f32_f32_f32;
+ else
+#endif
+ matmul_f32_f32_f32 = matmul_scalar_f32_f32_f32;
+ initialized = 1;
+ }
+ return matmul_f32_f32_f32(m, n, p, A, B, C, scale);
+}
+
+int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double) = _matmul_f32_f32_f32;
+
+/* ========================================================================== */
+/* f64_f64_f64 implementations */
+/* ========================================================================== */
+
+int matmul_scalar_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
+ const size_t ib = 64;
+ const size_t jb = 64;
+ const size_t kb = 8;
+
+#pragma omp parallel for schedule(static)
+ for (size_t ii = 0; ii < m; ii += ib) {
+ size_t i_end = (ii + ib < m) ? ii + ib : m;
+ for (size_t jj = 0; jj < p; jj += jb) {
+ size_t j_end = (jj + jb < p) ? jj + jb : p;
+ size_t ti = i_end - ii;
+ size_t tj = j_end - jj;
+ double acc[64 * 64];
+ memset(acc, 0, ti * tj * sizeof(double));
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n;
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j < j_end; j++) {
+ size_t lj = j - jj;
+ double sum = 0.0;
+ for (size_t k = kk; k < k_end; k++) {
+ sum += A[i * n + k] * B[k * p + j];
+ }
+ acc[li * tj + lj] += sum;
+ }
+ }
+ }
+
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j < j_end; j++) {
+ size_t lj = j - jj;
+ double v = acc[li * tj + lj];
+ if (scale > 1.0) v /= scale;
+ C[i * p + j] = v;
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+#ifdef __AVX2__
+static void pack_b_f64(size_t n, size_t p, const double *B, double *B_packed) {
+ size_t n4 = n / 4;
+ size_t p4 = p / 4;
+ for (size_t j4 = 0; j4 < p4; j4++) {
+ for (size_t k4 = 0; k4 < n4; k4++) {
+ double *dst = &B_packed[(j4 * n4 + k4) * 16];
+ for (size_t dk = 0; dk < 4; dk++) {
+ size_t k = k4 * 4 + dk;
+ for (size_t dj = 0; dj < 4; dj++) {
+ size_t j = j4 * 4 + dj;
+ dst[dk * 4 + dj] = B[k * p + j];
+ }
+ }
+ }
+ }
+}
+
+int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
+ size_t n4 = n / 4;
+ size_t p4 = p / 4;
+
+ double *B_packed;
+ if (posix_memalign((void **)&B_packed, 64, p4 * n4 * 16 * sizeof(double)) != 0) return -1;
+ pack_b_f64(n, p, B, B_packed);
+
+#pragma omp parallel for schedule(static)
+ for (size_t i = 0; i < m; i++) {
+ for (size_t j4 = 0; j4 < p4; j4++) {
+ __m256d result = _mm256_setzero_pd();
+ for (size_t k4 = 0; k4 < n4; k4++) {
+ for (size_t dk = 0; dk < 4; dk++) {
+ __m256d b_val = _mm256_load_pd(&B_packed[(j4 * n4 + k4) * 16 + dk * 4]);
+ __m256d a_bcast = _mm256_set1_pd(A[i * n + k4 * 4 + dk]);
+ result = _mm256_fmadd_pd(a_bcast, b_val, result);
+ }
+ }
+ double tmp[4] __attribute__((aligned(32)));
+ _mm256_store_pd(tmp, result);
+ for (size_t dj = 0; dj < 4; dj++) {
+ double v = tmp[dj];
+ if (scale > 1.0) v /= scale;
+ C[i * p + j4 * 4 + dj] = v;
+ }
+ }
+ for (size_t j = p4 * 4; j < p; j++) {
+ double sum = 0.0;
+ for (size_t k = 0; k < n; k++) {
+ sum += A[i * n + k] * B[k * p + j];
+ }
+ if (scale > 1.0) sum /= scale;
+ C[i * p + j] = sum;
+ }
+ }
+
+ free(B_packed);
+ return 0;
+}
+#endif
+
+#ifdef __AVX512F__
+static void pack_b_f64_512(size_t n, size_t p, const double *B, double *B_packed) {
+ size_t n8 = n / 8;
+ size_t p8 = p / 8;
+ for (size_t j8 = 0; j8 < p8; j8++) {
+ for (size_t k8 = 0; k8 < n8; k8++) {
+ double *dst = &B_packed[(j8 * n8 + k8) * 64];
+ for (size_t dk = 0; dk < 8; dk++) {
+ size_t k = k8 * 8 + dk;
+ for (size_t dj = 0; dj < 8; dj++) {
+ size_t j = j8 * 8 + dj;
+ dst[dk * 8 + dj] = B[k * p + j];
+ }
+ }
+ }
+ }
+}
+
+int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
+ size_t n8 = n / 8;
+ size_t p8 = p / 8;
+
+ double *B_packed;
+ if (posix_memalign((void **)&B_packed, 64, p8 * n8 * 64 * sizeof(double)) != 0) return -1;
+ pack_b_f64_512(n, p, B, B_packed);
+
+#pragma omp parallel for schedule(static)
+ for (size_t i = 0; i < m; i++) {
+ for (size_t j8 = 0; j8 < p8; j8++) {
+ __m512d result = _mm512_setzero_pd();
+ for (size_t k8 = 0; k8 < n8; k8++) {
+ for (size_t dk = 0; dk < 8; dk++) {
+ __m512d b_val = _mm512_load_pd(&B_packed[(j8 * n8 + k8) * 64 + dk * 8]);
+ __m512d a_bcast = _mm512_set1_pd(A[i * n + k8 * 8 + dk]);
+ result = _mm512_fmadd_pd(a_bcast, b_val, result);
+ }
+ }
+ double tmp[8] __attribute__((aligned(64)));
+ _mm512_store_pd(tmp, result);
+ for (size_t dj = 0; dj < 8; dj++) {
+ double v = tmp[dj];
+ if (scale > 1.0) v /= scale;
+ C[i * p + j8 * 8 + dj] = v;
+ }
+ }
+ for (size_t j = p8 * 8; j < p; j++) {
+ double sum = 0.0;
+ for (size_t k = 0; k < n; k++) {
+ sum += A[i * n + k] * B[k * p + j];
+ }
+ if (scale > 1.0) sum /= scale;
+ C[i * p + j] = sum;
+ }
+ }
+
+ free(B_packed);
+ return 0;
+}
+#endif
+
+static int _matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C,
+ double scale) {
+ static int initialized = 0;
+ if (!initialized) {
+ matmul_feature_t feat = matmul_get_feature();
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512)
+ matmul_f64_f64_f64 = matmul_avx512_f64_f64_f64;
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVX2)
+ matmul_f64_f64_f64 = matmul_avx2_f64_f64_f64;
+ else
+#endif
+ matmul_f64_f64_f64 = matmul_scalar_f64_f64_f64;
+ initialized = 1;
+ }
+ return matmul_f64_f64_f64(m, n, p, A, B, C, scale);
+}
+
+int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *,
+ double) = _matmul_f64_f64_f64;
diff --git a/src/matmul.h b/src/matmul.h
@@ -46,15 +46,27 @@ extern "C" {
#endif
extern int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double);
+extern int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double);
+extern int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, double);
#define matmul(m, n, p, A, B, C, scale) \
_Generic((A), \
uint8_t *: _Generic((B), \
int8_t *: _Generic((C), \
uint8_t *: matmul_u8_i8_u8 \
+ ) \
+ ), \
+ float *: _Generic((B), \
+ float *: _Generic((C), \
+ float *: matmul_f32_f32_f32 \
+ ) \
+ ), \
+ double *: _Generic((B), \
+ double *: _Generic((C), \
+ double *: matmul_f64_f64_f64 \
) \
) \
- )((m), (n), (p), (A), (B), (C), (scale))
+ )((m), (n), (p), (A), (B), (C), (scale))
#ifdef __cplusplus
}
diff --git a/test/benchmark.c b/test/benchmark.c
@@ -34,7 +34,7 @@ static double percentile(double *sorted, double p, int n) {
return sorted[lo] * (1.0 - frac) + sorted[hi] * frac;
}
-static void bench(size_t m, size_t n, size_t p, int runs) {
+static void bench_u8_i8_u8(size_t m, size_t n, size_t p, int runs) {
uint8_t *A = malloc(m * n);
int8_t *B = malloc(n * p);
uint8_t *C = malloc(m * p);
@@ -75,25 +75,148 @@ static void bench(size_t m, size_t n, size_t p, int runs) {
printf("%8zu x %8zu | %6.2f | %6.2f | %6.2f | %6.2f | %6.2f | %8.1f\n", m, n, percentile(times, 1, actual_runs),
percentile(times, 5, actual_runs), percentile(times, 50, actual_runs), percentile(times, 95, actual_runs),
percentile(times, 99, actual_runs), gflops);
+
+ free(A);
+ free(B);
+ free(C);
+ free(Cwarm);
+}
+
+static void bench_f32_f32_f32(size_t m, size_t n, size_t p, int runs) {
+ float *A = malloc(m * n * sizeof(float));
+ float *B = malloc(n * p * sizeof(float));
+ float *C = malloc(m * p * sizeof(float));
+ float *Cwarm = malloc(m * p * sizeof(float));
+ double times[RUNS];
+
+ if (!A || !B || !C || !Cwarm) {
+ fprintf(stderr, "OOM for %zu x %zu\n", m, n);
+ free(A);
+ free(B);
+ free(C);
+ free(Cwarm);
+ return;
+ }
+
+ for (size_t i = 0; i < m * n; i++) A[i] = (float)(rand() % 256);
+ for (size_t i = 0; i < n * p; i++) B[i] = (float)(rand() % 256);
+ memset(C, 0, m * p * sizeof(float));
+ memset(Cwarm, 0, m * p * sizeof(float));
+
+ matmul_f32_f32_f32(m, n, p, A, B, Cwarm, 0.0);
+
+ int actual_runs = runs;
+ if (m >= 4096) actual_runs = 3;
+
+ for (int r = 0; r < actual_runs; r++) {
+ memset(C, 0, m * p * sizeof(float));
+ struct timespec start = timespec_now();
+ matmul_f32_f32_f32(m, n, p, A, B, C, 0.0);
+ struct timespec end = timespec_now();
+ times[r] = timespec_diff_ms(start, end);
+ }
+
+ qsort(times, actual_runs, sizeof(double), compare_double);
+
+ double gflops = 2.0 * m * n * p / (percentile(times, 50, actual_runs) * 1e6);
+
+ printf("%8zu x %8zu | %6.2f | %6.2f | %6.2f | %6.2f | %6.2f | %8.1f\n", m, n, percentile(times, 1, actual_runs),
+ percentile(times, 5, actual_runs), percentile(times, 50, actual_runs), percentile(times, 95, actual_runs),
+ percentile(times, 99, actual_runs), gflops);
+
+ free(A);
+ free(B);
+ free(C);
+ free(Cwarm);
+}
+
+static void bench_f64_f64_f64(size_t m, size_t n, size_t p, int runs) {
+ double *A = malloc(m * n * sizeof(double));
+ double *B = malloc(n * p * sizeof(double));
+ double *C = malloc(m * p * sizeof(double));
+ double *Cwarm = malloc(m * p * sizeof(double));
+ double times[RUNS];
+
+ if (!A || !B || !C || !Cwarm) {
+ fprintf(stderr, "OOM for %zu x %zu\n", m, n);
+ free(A);
+ free(B);
+ free(C);
+ free(Cwarm);
+ return;
+ }
+
+ for (size_t i = 0; i < m * n; i++) A[i] = (double)(rand() % 256);
+ for (size_t i = 0; i < n * p; i++) B[i] = (double)(rand() % 256);
+ memset(C, 0, m * p * sizeof(double));
+ memset(Cwarm, 0, m * p * sizeof(double));
+
+ matmul_f64_f64_f64(m, n, p, A, B, Cwarm, 0.0);
+
+ int actual_runs = runs;
+ if (m >= 4096) actual_runs = 3;
+
+ for (int r = 0; r < actual_runs; r++) {
+ memset(C, 0, m * p * sizeof(double));
+ struct timespec start = timespec_now();
+ matmul_f64_f64_f64(m, n, p, A, B, C, 0.0);
+ struct timespec end = timespec_now();
+ times[r] = timespec_diff_ms(start, end);
+ }
+
+ qsort(times, actual_runs, sizeof(double), compare_double);
+
+ double gflops = 2.0 * m * n * p / (percentile(times, 50, actual_runs) * 1e6);
+
+ printf("%8zu x %8zu | %6.2f | %6.2f | %6.2f | %6.2f | %6.2f | %8.1f\n", m, n, percentile(times, 1, actual_runs),
+ percentile(times, 5, actual_runs), percentile(times, 50, actual_runs), percentile(times, 95, actual_runs),
+ percentile(times, 99, actual_runs), gflops);
+
+ free(A);
+ free(B);
+ free(C);
+ free(Cwarm);
}
int main(void) {
srand(42);
+ const size_t sizes[][3] = {
+ {16, 16, 16}, {64, 64, 64}, {256, 256, 256}, {1024, 1024, 1024}, {4096, 4096, 4096},
+ };
+ const int num_sizes = sizeof(sizes) / sizeof(sizes[0]);
+
printf("Benchmark: u8_i8_u8 matmul, %d runs per size\n", RUNS);
- printf("--------------------------------------------------------------\n");
+ printf("--------------------------------------------------------------------------------\n");
printf("%8s | %8s | %8s | %8s | %8s | %8s | %8s\n", "M x N", "1% (ms)", "5% (ms)", "50% (ms)", "95% (ms)", "99% (ms)",
"GFLOPS");
- printf("--------------------------------------------------------------\n");
+ printf("--------------------------------------------------------------------------------\n");
+ for (int i = 0; i < num_sizes; i++) {
+ bench_u8_i8_u8(sizes[i][0], sizes[i][1], sizes[i][2], RUNS);
+ }
+ printf("--------------------------------------------------------------------------------\n");
+ printf("\n");
- bench(16, 16, 16, RUNS);
- bench(64, 64, 64, RUNS);
- bench(256, 256, 256, RUNS);
- bench(1024, 1024, 1024, RUNS);
- bench(4096, 4096, 4096, RUNS);
- // bench(16384, 16384, 16384, RUNS);
+ printf("Benchmark: f32_f32_f32 matmul, %d runs per size\n", RUNS);
+ printf("--------------------------------------------------------------------------------\n");
+ printf("%8s | %8s | %8s | %8s | %8s | %8s | %8s\n", "M x N", "1% (ms)", "5% (ms)", "50% (ms)", "95% (ms)", "99% (ms)",
+ "GFLOPS");
+ printf("--------------------------------------------------------------------------------\n");
+ for (int i = 0; i < num_sizes; i++) {
+ bench_f32_f32_f32(sizes[i][0], sizes[i][1], sizes[i][2], RUNS);
+ }
+ printf("--------------------------------------------------------------------------------\n");
+ printf("\n");
- printf("--------------------------------------------------------------\n");
+ printf("Benchmark: f64_f64_f64 matmul, %d runs per size\n", RUNS);
+ printf("--------------------------------------------------------------------------------\n");
+ printf("%8s | %8s | %8s | %8s | %8s | %8s | %8s\n", "M x N", "1% (ms)", "5% (ms)", "50% (ms)", "95% (ms)", "99% (ms)",
+ "GFLOPS");
+ printf("--------------------------------------------------------------------------------\n");
+ for (int i = 0; i < num_sizes; i++) {
+ bench_f64_f64_f64(sizes[i][0], sizes[i][1], sizes[i][2], RUNS);
+ }
+ printf("--------------------------------------------------------------------------------\n");
return 0;
}
diff --git a/test/test_matmul.c b/test/test_matmul.c
@@ -230,6 +230,486 @@ static MunitResult test_dispatched_u8_i8_u8_scaled_medium(const MunitParameter *
return test_u8_i8_u8_scaled_medium("dispatched", matmul_u8_i8_u8, 0);
}
+/* ========================================================================== */
+/* f32_f32_f32 tests */
+/* ========================================================================== */
+
+static void ref_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
+ for (size_t i = 0; i < m; i++)
+ for (size_t j = 0; j < p; j++) {
+ double sum = 0.0;
+ for (size_t k = 0; k < n; k++) sum += (double)A[i * n + k] * (double)B[k * p + j];
+ if (scale > 1.0) sum /= scale;
+ C[i * p + j] = (float)sum;
+ }
+}
+
+static MunitResult test_f32_f32_f32_small(const char *name,
+ int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *,
+ float *, double),
+ double epsilon) {
+ float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+ float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
+ float C[4], E[4];
+
+ ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0);
+ matmul_fn(2, 3, 2, A, B, C, 0.0);
+
+ for (int i = 0; i < 4; i++) {
+ double d = (double)E[i] - (double)C[i];
+ if (d < 0) d = -d;
+ if (d > epsilon) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+
+static MunitResult test_f32_f32_f32_medium(const char *name,
+ int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *,
+ float *, double),
+ double epsilon) {
+ const size_t m = 64, n = 64, p = 64;
+ float *A = malloc(m * n * sizeof(float));
+ float *B = malloc(n * p * sizeof(float));
+ float *C = malloc(m * p * sizeof(float));
+ float *E = malloc(m * p * sizeof(float));
+ if (!A || !B || !C || !E) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_SKIP;
+ }
+
+ for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251);
+ for (size_t i = 0; i < n * p; i++) B[i] = (float)(((i * 11 + 17) % 211) - 105);
+ memset(C, 0, m * p * sizeof(float));
+ memset(E, 0, m * p * sizeof(float));
+
+ ref_f32_f32_f32(m, n, p, A, B, E, 0.0);
+ matmul_fn(m, n, p, A, B, C, 0.0);
+
+ for (size_t i = 0; i < m * p; i++) {
+ double d = (double)E[i] - (double)C[i];
+ if (d < 0) d = -d;
+ if (d > epsilon) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_FAIL;
+ }
+ }
+
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_OK;
+}
+
+static MunitResult test_f32_f32_f32_scaled_small(const char *name,
+ int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *,
+ float *, double),
+ double epsilon) {
+ float A[] = {8.0f, 16.0f, 24.0f, 32.0f, 40.0f, 48.0f};
+ float B[] = {2.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f};
+ float C[4], E[4];
+
+ ref_f32_f32_f32(2, 3, 2, A, B, E, 4.0);
+ matmul_fn(2, 3, 2, A, B, C, 4.0);
+
+ for (int i = 0; i < 4; i++) {
+ double d = (double)E[i] - (double)C[i];
+ if (d < 0) d = -d;
+ if (d > epsilon) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+
+static MunitResult test_f32_f32_f32_scaled_medium(const char *name,
+ int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *,
+ float *, double),
+ double epsilon) {
+ const size_t m = 64, n = 64, p = 64;
+ float *A = malloc(m * n * sizeof(float));
+ float *B = malloc(n * p * sizeof(float));
+ float *C = malloc(m * p * sizeof(float));
+ float *E = malloc(m * p * sizeof(float));
+ if (!A || !B || !C || !E) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_SKIP;
+ }
+
+ for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251);
+ for (size_t i = 0; i < n * p; i++) B[i] = (float)(((i * 11 + 17) % 211) - 105);
+ memset(C, 0, m * p * sizeof(float));
+ memset(E, 0, m * p * sizeof(float));
+
+ ref_f32_f32_f32(m, n, p, A, B, E, 8.0);
+ matmul_fn(m, n, p, A, B, C, 8.0);
+
+ for (size_t i = 0; i < m * p; i++) {
+ double d = (double)E[i] - (double)C[i];
+ if (d < 0) d = -d;
+ if (d > epsilon) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_FAIL;
+ }
+ }
+
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_OK;
+}
+
+static MunitResult test_scalar_f32_f32_f32(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_small("scalar", matmul_scalar_f32_f32_f32, 1e-5);
+}
+
+static MunitResult test_scalar_f32_f32_f32_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_medium("scalar", matmul_scalar_f32_f32_f32, 1e-3);
+}
+
+static MunitResult test_scalar_f32_f32_f32_scaled(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_scaled_small("scalar", matmul_scalar_f32_f32_f32, 1e-5);
+}
+
+static MunitResult test_scalar_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_scaled_medium("scalar", matmul_scalar_f32_f32_f32, 1e-3);
+}
+
+#ifdef __AVX2__
+static MunitResult test_avx2_f32_f32_f32(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_small("avx2", matmul_avx2_f32_f32_f32, 1e-5);
+}
+
+static MunitResult test_avx2_f32_f32_f32_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_medium("avx2", matmul_avx2_f32_f32_f32, 1e-3);
+}
+
+static MunitResult test_avx2_f32_f32_f32_scaled(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_scaled_small("avx2", matmul_avx2_f32_f32_f32, 1e-5);
+}
+
+static MunitResult test_avx2_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_scaled_medium("avx2", matmul_avx2_f32_f32_f32, 1e-3);
+}
+#endif
+
+#ifdef __AVX512F__
+static MunitResult test_avx512_f32_f32_f32(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_small("avx512", matmul_avx512_f32_f32_f32, 1e-5);
+}
+
+static MunitResult test_avx512_f32_f32_f32_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_medium("avx512", matmul_avx512_f32_f32_f32, 1e-3);
+}
+
+static MunitResult test_avx512_f32_f32_f32_scaled(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_scaled_small("avx512", matmul_avx512_f32_f32_f32, 1e-5);
+}
+
+static MunitResult test_avx512_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_scaled_medium("avx512", matmul_avx512_f32_f32_f32, 1e-3);
+}
+#endif
+
+static MunitResult test_dispatched_f32_f32_f32(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_small("dispatched", matmul_f32_f32_f32, 1e-5);
+}
+
+static MunitResult test_dispatched_f32_f32_f32_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_medium("dispatched", matmul_f32_f32_f32, 1e-3);
+}
+
+static MunitResult test_dispatched_f32_f32_f32_scaled(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_scaled_small("dispatched", matmul_f32_f32_f32, 1e-5);
+}
+
+static MunitResult test_dispatched_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f32_f32_f32_scaled_medium("dispatched", matmul_f32_f32_f32, 1e-3);
+}
+
+/* ========================================================================== */
+/* f64_f64_f64 tests */
+/* ========================================================================== */
+
+static void ref_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
+ for (size_t i = 0; i < m; i++)
+ for (size_t j = 0; j < p; j++) {
+ double sum = 0.0;
+ for (size_t k = 0; k < n; k++) sum += A[i * n + k] * B[k * p + j];
+ if (scale > 1.0) sum /= scale;
+ C[i * p + j] = sum;
+ }
+}
+
+static MunitResult test_f64_f64_f64_small(const char *name,
+ int (*matmul_fn)(size_t, size_t, size_t, const double *, const double *,
+ double *, double),
+ double epsilon) {
+ double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+ double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0};
+ double C[4], E[4];
+
+ ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0);
+ matmul_fn(2, 3, 2, A, B, C, 0.0);
+
+ for (int i = 0; i < 4; i++) {
+ double d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > epsilon) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+
+static MunitResult test_f64_f64_f64_medium(const char *name,
+ int (*matmul_fn)(size_t, size_t, size_t, const double *, const double *,
+ double *, double),
+ double epsilon) {
+ const size_t m = 64, n = 64, p = 64;
+ double *A = malloc(m * n * sizeof(double));
+ double *B = malloc(n * p * sizeof(double));
+ double *C = malloc(m * p * sizeof(double));
+ double *E = malloc(m * p * sizeof(double));
+ if (!A || !B || !C || !E) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_SKIP;
+ }
+
+ for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251);
+ for (size_t i = 0; i < n * p; i++) B[i] = (double)(((i * 11 + 17) % 211) - 105);
+ memset(C, 0, m * p * sizeof(double));
+ memset(E, 0, m * p * sizeof(double));
+
+ ref_f64_f64_f64(m, n, p, A, B, E, 0.0);
+ matmul_fn(m, n, p, A, B, C, 0.0);
+
+ for (size_t i = 0; i < m * p; i++) {
+ double d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > epsilon) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_FAIL;
+ }
+ }
+
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_OK;
+}
+
+static MunitResult test_f64_f64_f64_scaled_small(const char *name,
+ int (*matmul_fn)(size_t, size_t, size_t, const double *,
+ const double *, double *, double),
+ double epsilon) {
+ double A[] = {8.0, 16.0, 24.0, 32.0, 40.0, 48.0};
+ double B[] = {2.0, 0.0, 0.0, 2.0, 0.0, 0.0};
+ double C[4], E[4];
+
+ ref_f64_f64_f64(2, 3, 2, A, B, E, 4.0);
+ matmul_fn(2, 3, 2, A, B, C, 4.0);
+
+ for (int i = 0; i < 4; i++) {
+ double d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > epsilon) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+
+static MunitResult test_f64_f64_f64_scaled_medium(const char *name,
+ int (*matmul_fn)(size_t, size_t, size_t, const double *,
+ const double *, double *, double),
+ double epsilon) {
+ const size_t m = 64, n = 64, p = 64;
+ double *A = malloc(m * n * sizeof(double));
+ double *B = malloc(n * p * sizeof(double));
+ double *C = malloc(m * p * sizeof(double));
+ double *E = malloc(m * p * sizeof(double));
+ if (!A || !B || !C || !E) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_SKIP;
+ }
+
+ for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251);
+ for (size_t i = 0; i < n * p; i++) B[i] = (double)(((i * 11 + 17) % 211) - 105);
+ memset(C, 0, m * p * sizeof(double));
+ memset(E, 0, m * p * sizeof(double));
+
+ ref_f64_f64_f64(m, n, p, A, B, E, 8.0);
+ matmul_fn(m, n, p, A, B, C, 8.0);
+
+ for (size_t i = 0; i < m * p; i++) {
+ double d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > epsilon) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_FAIL;
+ }
+ }
+
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_OK;
+}
+
+static MunitResult test_scalar_f64_f64_f64(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_small("scalar", matmul_scalar_f64_f64_f64, 1e-12);
+}
+
+static MunitResult test_scalar_f64_f64_f64_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_medium("scalar", matmul_scalar_f64_f64_f64, 1e-9);
+}
+
+static MunitResult test_scalar_f64_f64_f64_scaled(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_scaled_small("scalar", matmul_scalar_f64_f64_f64, 1e-12);
+}
+
+static MunitResult test_scalar_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_scaled_medium("scalar", matmul_scalar_f64_f64_f64, 1e-9);
+}
+
+#ifdef __AVX2__
+static MunitResult test_avx2_f64_f64_f64(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_small("avx2", matmul_avx2_f64_f64_f64, 1e-12);
+}
+
+static MunitResult test_avx2_f64_f64_f64_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_medium("avx2", matmul_avx2_f64_f64_f64, 1e-9);
+}
+
+static MunitResult test_avx2_f64_f64_f64_scaled(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_scaled_small("avx2", matmul_avx2_f64_f64_f64, 1e-12);
+}
+
+static MunitResult test_avx2_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_scaled_medium("avx2", matmul_avx2_f64_f64_f64, 1e-9);
+}
+#endif
+
+#ifdef __AVX512F__
+static MunitResult test_avx512_f64_f64_f64(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_small("avx512", matmul_avx512_f64_f64_f64, 1e-12);
+}
+
+static MunitResult test_avx512_f64_f64_f64_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_medium("avx512", matmul_avx512_f64_f64_f64, 1e-9);
+}
+
+static MunitResult test_avx512_f64_f64_f64_scaled(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_scaled_small("avx512", matmul_avx512_f64_f64_f64, 1e-12);
+}
+
+static MunitResult test_avx512_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_scaled_medium("avx512", matmul_avx512_f64_f64_f64, 1e-9);
+}
+#endif
+
+static MunitResult test_dispatched_f64_f64_f64(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_small("dispatched", matmul_f64_f64_f64, 1e-12);
+}
+
+static MunitResult test_dispatched_f64_f64_f64_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_medium("dispatched", matmul_f64_f64_f64, 1e-9);
+}
+
+static MunitResult test_dispatched_f64_f64_f64_scaled(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_scaled_small("dispatched", matmul_f64_f64_f64, 1e-12);
+}
+
+static MunitResult test_dispatched_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_f64_f64_f64_scaled_medium("dispatched", matmul_f64_f64_f64, 1e-9);
+}
+
static MunitTest tests[] = {
{"/scalar-u8-i8-u8", test_scalar_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/scalar-u8-i8-u8-medium", test_scalar_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
@@ -247,6 +727,52 @@ static MunitTest tests[] = {
{"/dispatched-u8-i8-u8-scaled", test_dispatched_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/dispatched-u8-i8-u8-scaled-medium", test_dispatched_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
NULL},
+ {"/scalar-f32-f32-f32", test_scalar_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/scalar-f32-f32-f32-medium", test_scalar_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/scalar-f32-f32-f32-scaled", test_scalar_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/scalar-f32-f32-f32-scaled-medium", test_scalar_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
+ NULL},
+#ifdef __AVX2__
+ {"/avx2-f32-f32-f32", test_avx2_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx2-f32-f32-f32-medium", test_avx2_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx2-f32-f32-f32-scaled", test_avx2_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx2-f32-f32-f32-scaled-medium", test_avx2_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+#endif
+#ifdef __AVX512F__
+ {"/avx512-f32-f32-f32", test_avx512_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx512-f32-f32-f32-medium", test_avx512_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx512-f32-f32-f32-scaled", test_avx512_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx512-f32-f32-f32-scaled-medium", test_avx512_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
+ NULL},
+#endif
+ {"/dispatched-f32-f32-f32", test_dispatched_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/dispatched-f32-f32-f32-medium", test_dispatched_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/dispatched-f32-f32-f32-scaled", test_dispatched_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/dispatched-f32-f32-f32-scaled-medium", test_dispatched_f32_f32_f32_scaled_medium, NULL, NULL,
+ MUNIT_TEST_OPTION_NONE, NULL},
+ {"/scalar-f64-f64-f64", test_scalar_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/scalar-f64-f64-f64-medium", test_scalar_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/scalar-f64-f64-f64-scaled", test_scalar_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/scalar-f64-f64-f64-scaled-medium", test_scalar_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
+ NULL},
+#ifdef __AVX2__
+ {"/avx2-f64-f64-f64", test_avx2_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx2-f64-f64-f64-medium", test_avx2_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx2-f64-f64-f64-scaled", test_avx2_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx2-f64-f64-f64-scaled-medium", test_avx2_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+#endif
+#ifdef __AVX512F__
+ {"/avx512-f64-f64-f64", test_avx512_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx512-f64-f64-f64-medium", test_avx512_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx512-f64-f64-f64-scaled", test_avx512_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx512-f64-f64-f64-scaled-medium", test_avx512_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
+ NULL},
+#endif
+ {"/dispatched-f64-f64-f64", test_dispatched_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/dispatched-f64-f64-f64-medium", test_dispatched_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/dispatched-f64-f64-f64-scaled", test_dispatched_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/dispatched-f64-f64-f64-scaled-medium", test_dispatched_f64_f64_f64_scaled_medium, NULL, NULL,
+ MUNIT_TEST_OPTION_NONE, NULL},
{NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}};
static const MunitSuite suite = {"/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE};
diff --git a/test/test_matmul_simd.h b/test/test_matmul_simd.h
@@ -70,6 +70,22 @@ int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, c
double scale);
#endif
+int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale);
+#ifdef __AVX2__
+int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale);
+#endif
+#ifdef __AVX512F__
+int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale);
+#endif
+
+int matmul_scalar_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale);
+#ifdef __AVX2__
+int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale);
+#endif
+#ifdef __AVX512F__
+int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale);
+#endif
+
#ifdef __cplusplus
}
#endif