matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

commit a61e69663358bb8e49a1e625b4b564f4b1e756f9
parent 2ac74e34bb55a7f60b8a1ee4e9d68ce93924efa8
Author: finwo <finwo@pm.me>
Date:   Tue, 21 Apr 2026 21:35:02 +0200

Macro tested now

Diffstat:
M src/matmul.h | 22 +++++++++++++++++++---
M test/test_matmul.c | 355 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 374 insertions(+), 3 deletions(-)

diff --git a/src/matmul.h b/src/matmul.h @@ -45,9 +45,25 @@ extern "C" { #endif -extern int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double); -extern int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double); -extern int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, double); +#define __matmul_TYPE_u8 uint8_t +#define __matmul_TYPE_i8 int8_t +#define __matmul_TYPE_f32 float +#define __matmul_TYPE_f64 double + +#define __matmul_EXTERN(A, B, C) extern int (*matmul_##A##_##B##_##C)(size_t, size_t, size_t, const __matmul_TYPE_##A *, const __matmul_TYPE_##B *, __matmul_TYPE_##C *, double); + +#define matmul_externs \ + __matmul_EXTERN(u8 , i8 , u8) __matmul_EXTERN(u8 , i8 , f32) __matmul_EXTERN(u8 , i8 , f64) \ + __matmul_EXTERN(u8 , f32, u8) __matmul_EXTERN(u8 , f32, f32) __matmul_EXTERN(u8 , f32, f64) \ + __matmul_EXTERN(u8 , f64, u8) __matmul_EXTERN(u8 , f64, f32) __matmul_EXTERN(u8 , f64, f64) \ + __matmul_EXTERN(f32, i8 , u8) __matmul_EXTERN(f32, i8 , f32) __matmul_EXTERN(f32, i8 , f64) \ + __matmul_EXTERN(f32, f32, u8) __matmul_EXTERN(f32, f32, f32) __matmul_EXTERN(f32, f32, f64) \ + __matmul_EXTERN(f32, f64, u8) __matmul_EXTERN(f32, f64, f32) __matmul_EXTERN(f32, f64, f64) \ + __matmul_EXTERN(f64, i8 , u8) __matmul_EXTERN(f64, i8 , f32) __matmul_EXTERN(f64, i8 , f64) \ + __matmul_EXTERN(f64, f32, u8) __matmul_EXTERN(f64, f32, f32) __matmul_EXTERN(f64, f32, f64) \ + __matmul_EXTERN(f64, f64, u8) __matmul_EXTERN(f64, f64, f32) __matmul_EXTERN(f64, f64, f64) + +matmul_externs #define __matmul_C(type_a,type_b) \ _Generic((C), \ diff --git a/test/test_matmul.c b/test/test_matmul.c @@ -710,6 +710,345 @@ static MunitResult test_dispatched_f64_f64_f64_scaled_medium(const MunitParamete return test_f64_f64_f64_scaled_medium("dispatched", matmul_f64_f64_f64, 1e-9); } +/* ========================================================================== */ 
+/* matmul generic macro tests */ +/* ========================================================================== */ + +static MunitResult test_matmul_u8_i8_u8(const char *name, const char *type) { + const uint8_t A[] = {1, 2, 3, 4, 5, 6}; + const int8_t B[] = {1, 0, 0, 1, 0, 0}; + uint8_t C[4], E[4]; + + ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0); + matmul(2, 3, 2, A, B, C, 0.0); + + for (int i = 0; i < 4; i++) { + int d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > 0) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_matmul_u8_i8_u8_scaled(const char *name, const char *type) { + const uint8_t A[] = {8, 16, 24, 32, 40, 48}; + const int8_t B[] = {2, 0, 0, 2, 0, 0}; + uint8_t C[4], E[4]; + + ref_u8_i8_u8(2, 3, 2, A, B, E, 4.0); + matmul(2, 3, 2, A, B, C, 4.0); + + for (int i = 0; i < 4; i++) { + int d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > 0) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_matmul_u8_i8_u8_medium(const char *name, const char *type) { + const size_t m = 64, n = 64, p = 64; + uint8_t *A = malloc(m * n); + int8_t *B = malloc(n * p); + uint8_t *C = malloc(m * p); + uint8_t *E = malloc(m * p); + if (!A || !B || !C || !E) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_SKIP; + } + + for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)((i * 7 + 13) % 251); + for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(((i * 11 + 17) % 211) - 105); + memset(C, 0, m * p); + memset(E, 0, m * p); + + ref_u8_i8_u8(m, n, p, A, B, E, 0.0); + matmul(m, n, p, A, B, C, 0.0); + + for (size_t i = 0; i < m * p; i++) { + int d = (int)E[i] - (int)C[i]; + if (d < 0) d = -d; + if (d > 0) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_FAIL; + } + } + + free(A); + free(B); + free(C); + free(E); + return MUNIT_OK; +} + +static MunitResult test_matmul_f32_f32_f32(const char *name, const char *type) { + const float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + const float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}; + 
float C[4], E[4]; + + ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0); + matmul(2, 3, 2, A, B, C, 0.0); + + for (int i = 0; i < 4; i++) { + double d = (double)E[i] - (double)C[i]; + if (d < 0) d = -d; + if (d > 1e-5) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_matmul_f32_f32_f32_scaled(const char *name, const char *type) { + const float A[] = {8.0f, 16.0f, 24.0f, 32.0f, 40.0f, 48.0f}; + const float B[] = {2.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f}; + float C[4], E[4]; + + ref_f32_f32_f32(2, 3, 2, A, B, E, 4.0); + matmul(2, 3, 2, A, B, C, 4.0); + + for (int i = 0; i < 4; i++) { + double d = (double)E[i] - (double)C[i]; + if (d < 0) d = -d; + if (d > 1e-5) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_matmul_f32_f32_f32_medium(const char *name, const char *type) { + const size_t m = 64, n = 64, p = 64; + float *A = malloc(m * n * sizeof(float)); + float *B = malloc(n * p * sizeof(float)); + float *C = malloc(m * p * sizeof(float)); + float *E = malloc(m * p * sizeof(float)); + if (!A || !B || !C || !E) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_SKIP; + } + + for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251); + for (size_t i = 0; i < n * p; i++) B[i] = (float)(((i * 11 + 17) % 211) - 105); + memset(C, 0, m * p * sizeof(float)); + memset(E, 0, m * p * sizeof(float)); + + ref_f32_f32_f32(m, n, p, A, B, E, 0.0); + matmul(m, n, p, A, B, C, 0.0); + + for (size_t i = 0; i < m * p; i++) { + double d = (double)E[i] - (double)C[i]; + if (d < 0) d = -d; + if (d > 1e-3) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_FAIL; + } + } + + free(A); + free(B); + free(C); + free(E); + return MUNIT_OK; +} + +static MunitResult test_matmul_f64_f64_f64(const char *name, const char *type) { + const double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + const double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0}; + double C[4], E[4]; + + ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0); + matmul(2, 3, 2, A, B, C, 0.0); + + for 
(int i = 0; i < 4; i++) { + double d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > 1e-12) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_matmul_f64_f64_f64_scaled(const char *name, const char *type) { + const double A[] = {8.0, 16.0, 24.0, 32.0, 40.0, 48.0}; + const double B[] = {2.0, 0.0, 0.0, 2.0, 0.0, 0.0}; + double C[4], E[4]; + + ref_f64_f64_f64(2, 3, 2, A, B, E, 4.0); + matmul(2, 3, 2, A, B, C, 4.0); + + for (int i = 0; i < 4; i++) { + double d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > 1e-12) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_matmul_f64_f64_f64_medium(const char *name, const char *type) { + const size_t m = 64, n = 64, p = 64; + double *A = malloc(m * n * sizeof(double)); + double *B = malloc(n * p * sizeof(double)); + double *C = malloc(m * p * sizeof(double)); + double *E = malloc(m * p * sizeof(double)); + if (!A || !B || !C || !E) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_SKIP; + } + + for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251); + for (size_t i = 0; i < n * p; i++) B[i] = (double)(((i * 11 + 17) % 211) - 105); + memset(C, 0, m * p * sizeof(double)); + memset(E, 0, m * p * sizeof(double)); + + ref_f64_f64_f64(m, n, p, A, B, E, 0.0); + matmul(m, n, p, A, B, C, 0.0); + + for (size_t i = 0; i < m * p; i++) { + double d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > 1e-9) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_FAIL; + } + } + + free(A); + free(B); + free(C); + free(E); + return MUNIT_OK; +} + +static MunitResult test_matmul_u8_i8_u8_const(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_matmul_u8_i8_u8("const", "u8_i8_u8"); +} + +static MunitResult test_matmul_u8_i8_u8_nonconst(const MunitParameter *params, void *data) { + (void)params; + (void)data; + uint8_t A[] = {1, 2, 3, 4, 5, 6}; + int8_t B[] = {1, 0, 0, 1, 0, 0}; + uint8_t C[4], E[4]; + + ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0); + 
matmul(2, 3, 2, A, B, C, 0.0); + + for (int i = 0; i < 4; i++) { + int d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > 0) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_matmul_u8_i8_u8_scaled_const(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_matmul_u8_i8_u8_scaled("const", "u8_i8_u8"); +} + +static MunitResult test_matmul_u8_i8_u8_medium_const(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_matmul_u8_i8_u8_medium("const", "u8_i8_u8"); +} + +static MunitResult test_matmul_f32_f32_f32_const(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_matmul_f32_f32_f32("const", "f32_f32_f32"); +} + +static int matmul_f32_f32_f32_nonconst(size_t m, size_t n, size_t p, float *A, float *B, float *C, double scale) { + return matmul(m, n, p, A, B, C, scale); +} + +static MunitResult test_matmul_f32_f32_f32_nonconst(const MunitParameter *params, void *data) { + (void)params; + (void)data; + float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}; + float C[4], E[4]; + + ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0); + matmul_f32_f32_f32_nonconst(2, 3, 2, A, B, C, 0.0); + + for (int i = 0; i < 4; i++) { + double d = (double)E[i] - (double)C[i]; + if (d < 0) d = -d; + if (d > 1e-5) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_matmul_f32_f32_f32_scaled_const(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_matmul_f32_f32_f32_scaled("const", "f32_f32_f32"); +} + +static MunitResult test_matmul_f32_f32_f32_medium_const(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_matmul_f32_f32_f32_medium("const", "f32_f32_f32"); +} + +static MunitResult test_matmul_f64_f64_f64_const(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_matmul_f64_f64_f64("const", "f64_f64_f64"); +} 
+ +static int matmul_f64_f64_f64_nonconst(size_t m, size_t n, size_t p, double *A, double *B, double *C, double scale) { + return matmul(m, n, p, A, B, C, scale); +} + +static MunitResult test_matmul_f64_f64_f64_nonconst(const MunitParameter *params, void *data) { + (void)params; + (void)data; + double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0}; + double C[4], E[4]; + + ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0); + matmul_f64_f64_f64_nonconst(2, 3, 2, A, B, C, 0.0); + + for (int i = 0; i < 4; i++) { + double d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > 1e-12) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_matmul_f64_f64_f64_scaled_const(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_matmul_f64_f64_f64_scaled("const", "f64_f64_f64"); +} + +static MunitResult test_matmul_f64_f64_f64_medium_const(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_matmul_f64_f64_f64_medium("const", "f64_f64_f64"); +} + static MunitTest tests[] = { {"/scalar-u8-i8-u8", test_scalar_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/scalar-u8-i8-u8-medium", test_scalar_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, @@ -773,6 +1112,22 @@ static MunitTest tests[] = { {"/dispatched-f64-f64-f64-scaled", test_dispatched_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/dispatched-f64-f64-f64-scaled-medium", test_dispatched_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/matmul-u8-i8-u8-const", test_matmul_u8_i8_u8_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/matmul-u8-i8-u8-nonconst", test_matmul_u8_i8_u8_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/matmul-u8-i8-u8-scaled-const", test_matmul_u8_i8_u8_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/matmul-u8-i8-u8-medium-const", test_matmul_u8_i8_u8_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + 
{"/matmul-f32-f32-f32-const", test_matmul_f32_f32_f32_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/matmul-f32-f32-f32-nonconst", test_matmul_f32_f32_f32_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/matmul-f32-f32-f32-scaled-const", test_matmul_f32_f32_f32_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, + NULL}, + {"/matmul-f32-f32-f32-medium-const", test_matmul_f32_f32_f32_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, + NULL}, + {"/matmul-f64-f64-f64-const", test_matmul_f64_f64_f64_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/matmul-f64-f64-f64-nonconst", test_matmul_f64_f64_f64_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/matmul-f64-f64-f64-scaled-const", test_matmul_f64_f64_f64_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, + NULL}, + {"/matmul-f64-f64-f64-medium-const", test_matmul_f64_f64_f64_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, + NULL}, {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}}; static const MunitSuite suite = {"/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE};