commit a61e69663358bb8e49a1e625b4b564f4b1e756f9
parent 2ac74e34bb55a7f60b8a1ee4e9d68ce93924efa8
Author: finwo <finwo@pm.me>
Date: Tue, 21 Apr 2026 21:35:02 +0200
Macro tested now
Diffstat:
2 files changed, 374 insertions(+), 3 deletions(-)
diff --git a/src/matmul.h b/src/matmul.h
@@ -45,9 +45,25 @@
extern "C" {
#endif
-extern int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double);
-extern int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double);
-extern int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, double);
+#define __matmul_TYPE_u8 uint8_t /* type-suffix token -> C element type */
+#define __matmul_TYPE_i8 int8_t
+#define __matmul_TYPE_f32 float
+#define __matmul_TYPE_f64 double
+/* Declares the extern function pointer matmul_<A>_<B>_<C>; A and B select the const input element types, C the output type. */
+#define __matmul_EXTERN(A, B, C) extern int (*matmul_##A##_##B##_##C)(size_t, size_t, size_t, const __matmul_TYPE_##A *, const __matmul_TYPE_##B *, __matmul_TYPE_##C *, double);
+/* One declaration per combination of A in {u8, f32, f64}, B in {i8, f32, f64}, C in {u8, f32, f64}: 27 in total. */
+#define matmul_externs \
+ __matmul_EXTERN(u8 , i8 , u8) __matmul_EXTERN(u8 , i8 , f32) __matmul_EXTERN(u8 , i8 , f64) \
+ __matmul_EXTERN(u8 , f32, u8) __matmul_EXTERN(u8 , f32, f32) __matmul_EXTERN(u8 , f32, f64) \
+ __matmul_EXTERN(u8 , f64, u8) __matmul_EXTERN(u8 , f64, f32) __matmul_EXTERN(u8 , f64, f64) \
+ __matmul_EXTERN(f32, i8 , u8) __matmul_EXTERN(f32, i8 , f32) __matmul_EXTERN(f32, i8 , f64) \
+ __matmul_EXTERN(f32, f32, u8) __matmul_EXTERN(f32, f32, f32) __matmul_EXTERN(f32, f32, f64) \
+ __matmul_EXTERN(f32, f64, u8) __matmul_EXTERN(f32, f64, f32) __matmul_EXTERN(f32, f64, f64) \
+ __matmul_EXTERN(f64, i8 , u8) __matmul_EXTERN(f64, i8 , f32) __matmul_EXTERN(f64, i8 , f64) \
+ __matmul_EXTERN(f64, f32, u8) __matmul_EXTERN(f64, f32, f32) __matmul_EXTERN(f64, f32, f64) \
+ __matmul_EXTERN(f64, f64, u8) __matmul_EXTERN(f64, f64, f32) __matmul_EXTERN(f64, f64, f64)
+/* Expand the list here to emit the declarations. NOTE(review): identifiers containing "__" are reserved in C. */
+matmul_externs
#define __matmul_C(type_a,type_b) \
_Generic((C), \
diff --git a/test/test_matmul.c b/test/test_matmul.c
@@ -710,6 +710,345 @@ static MunitResult test_dispatched_f64_f64_f64_scaled_medium(const MunitParamete
return test_f64_f64_f64_scaled_medium("dispatched", matmul_f64_f64_f64, 1e-9);
}
+/* ========================================================================== */
+/* matmul generic macro tests */
+/* ========================================================================== */
+/* Runs ref_u8_i8_u8() and matmul() on the same fixed inputs (sizes 2, 3, 2; last argument 0.0); any differing output fails. */
+static MunitResult test_matmul_u8_i8_u8(const char *name, const char *type) {
+ const uint8_t A[] = {1, 2, 3, 4, 5, 6};
+ const int8_t B[] = {1, 0, 0, 1, 0, 0};
+ uint8_t C[4], E[4];
+ /* NOTE(review): name and type are never read, and the result of matmul() is discarded. */
+ ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0);
+ matmul(2, 3, 2, A, B, C, 0.0);
+ /* E is the reference output, C the macro's output; d is their absolute difference and must be 0. */
+ for (int i = 0; i < 4; i++) {
+ int d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > 0) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+/* Same comparison as the unscaled u8 test, with larger inputs and 4.0 as the last argument of both calls. */
+static MunitResult test_matmul_u8_i8_u8_scaled(const char *name, const char *type) {
+ const uint8_t A[] = {8, 16, 24, 32, 40, 48};
+ const int8_t B[] = {2, 0, 0, 2, 0, 0};
+ uint8_t C[4], E[4];
+ /* NOTE(review): name and type are never read, and the result of matmul() is discarded. */
+ ref_u8_i8_u8(2, 3, 2, A, B, E, 4.0);
+ matmul(2, 3, 2, A, B, C, 4.0);
+ /* Exact match required between reference output E and macro output C. */
+ for (int i = 0; i < 4; i++) {
+ int d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > 0) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+/* 64x64x64 variant: fills A and B from fixed integer formulas and fails if any matmul() output byte differs from ref_u8_i8_u8(). */
+static MunitResult test_matmul_u8_i8_u8_medium(const char *name, const char *type) {
+ const size_t m = 64, n = 64, p = 64;
+ uint8_t *A = malloc(m * n);
+ int8_t *B = malloc(n * p);
+ uint8_t *C = malloc(m * p);
+ uint8_t *E = malloc(m * p);
+ if (!A || !B || !C || !E) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_SKIP;
+ }
+ /* Subtract 105 as int so B stays in [-105, 105] without narrowing a wrapped size_t value. */
+ for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)((i * 7 + 13) % 251);
+ for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)((int)((i * 11 + 17) % 211) - 105);
+ memset(C, 0, m * p);
+ memset(E, 0, m * p);
+ /* NOTE(review): name and type are never read, and the result of matmul() is discarded. */
+ ref_u8_i8_u8(m, n, p, A, B, E, 0.0);
+ matmul(m, n, p, A, B, C, 0.0);
+ /* Exact comparison; every exit path below releases all four buffers. */
+ for (size_t i = 0; i < m * p; i++) {
+ int d = (int)E[i] - (int)C[i];
+ if (d < 0) d = -d;
+ if (d > 0) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_FAIL;
+ }
+ }
+
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_OK;
+}
+/* Runs ref_f32_f32_f32() and matmul() on the same fixed float inputs (sizes 2, 3, 2; last argument 0.0); tolerance is 1e-5. */
+static MunitResult test_matmul_f32_f32_f32(const char *name, const char *type) {
+ const float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+ const float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
+ float C[4], E[4];
+ /* NOTE(review): name and type are never read, and the result of matmul() is discarded. */
+ ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0);
+ matmul(2, 3, 2, A, B, C, 0.0);
+ /* The difference is taken in double; any |E[i] - C[i]| above 1e-5 fails the test. */
+ for (int i = 0; i < 4; i++) {
+ double d = (double)E[i] - (double)C[i];
+ if (d < 0) d = -d;
+ if (d > 1e-5) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+/* Same comparison as the unscaled f32 test, with larger inputs and 4.0 as the last argument of both calls. */
+static MunitResult test_matmul_f32_f32_f32_scaled(const char *name, const char *type) {
+ const float A[] = {8.0f, 16.0f, 24.0f, 32.0f, 40.0f, 48.0f};
+ const float B[] = {2.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f};
+ float C[4], E[4];
+ /* NOTE(review): name and type are never read, and the result of matmul() is discarded. */
+ ref_f32_f32_f32(2, 3, 2, A, B, E, 4.0);
+ matmul(2, 3, 2, A, B, C, 4.0);
+ /* Any |E[i] - C[i]| above 1e-5 fails the test. */
+ for (int i = 0; i < 4; i++) {
+ double d = (double)E[i] - (double)C[i];
+ if (d < 0) d = -d;
+ if (d > 1e-5) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+/* 64x64x64 variant: fills A and B from fixed formulas and fails if any matmul() output differs from ref_f32_f32_f32() by more than 1e-3. */
+static MunitResult test_matmul_f32_f32_f32_medium(const char *name, const char *type) {
+ const size_t m = 64, n = 64, p = 64;
+ float *A = malloc(m * n * sizeof(float));
+ float *B = malloc(n * p * sizeof(float));
+ float *C = malloc(m * p * sizeof(float));
+ float *E = malloc(m * p * sizeof(float));
+ if (!A || !B || !C || !E) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_SKIP;
+ }
+ /* Subtract 105 as int: in size_t arithmetic the result wrapped to a huge positive value instead of going negative. */
+ for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251);
+ for (size_t i = 0; i < n * p; i++) B[i] = (float)((int)((i * 11 + 17) % 211) - 105);
+ memset(C, 0, m * p * sizeof(float));
+ memset(E, 0, m * p * sizeof(float));
+ /* NOTE(review): name and type are never read, and the result of matmul() is discarded. */
+ ref_f32_f32_f32(m, n, p, A, B, E, 0.0);
+ matmul(m, n, p, A, B, C, 0.0);
+ /* Tolerance is 1e-3 per element; every exit path below releases all four buffers. */
+ for (size_t i = 0; i < m * p; i++) {
+ double d = (double)E[i] - (double)C[i];
+ if (d < 0) d = -d;
+ if (d > 1e-3) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_FAIL;
+ }
+ }
+
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_OK;
+}
+/* Runs ref_f64_f64_f64() and matmul() on the same fixed double inputs (sizes 2, 3, 2; last argument 0.0); tolerance is 1e-12. */
+static MunitResult test_matmul_f64_f64_f64(const char *name, const char *type) {
+ const double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+ const double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0};
+ double C[4], E[4];
+ /* NOTE(review): name and type are never read, and the result of matmul() is discarded. */
+ ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0);
+ matmul(2, 3, 2, A, B, C, 0.0);
+ /* Any |E[i] - C[i]| above 1e-12 fails the test. */
+ for (int i = 0; i < 4; i++) {
+ double d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > 1e-12) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+/* Same comparison as the unscaled f64 test, with larger inputs and 4.0 as the last argument of both calls. */
+static MunitResult test_matmul_f64_f64_f64_scaled(const char *name, const char *type) {
+ const double A[] = {8.0, 16.0, 24.0, 32.0, 40.0, 48.0};
+ const double B[] = {2.0, 0.0, 0.0, 2.0, 0.0, 0.0};
+ double C[4], E[4];
+ /* NOTE(review): name and type are never read, and the result of matmul() is discarded. */
+ ref_f64_f64_f64(2, 3, 2, A, B, E, 4.0);
+ matmul(2, 3, 2, A, B, C, 4.0);
+ /* Any |E[i] - C[i]| above 1e-12 fails the test. */
+ for (int i = 0; i < 4; i++) {
+ double d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > 1e-12) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+/* 64x64x64 variant: fills A and B from fixed formulas and fails if any matmul() output differs from ref_f64_f64_f64() by more than 1e-9. */
+static MunitResult test_matmul_f64_f64_f64_medium(const char *name, const char *type) {
+ const size_t m = 64, n = 64, p = 64;
+ double *A = malloc(m * n * sizeof(double));
+ double *B = malloc(n * p * sizeof(double));
+ double *C = malloc(m * p * sizeof(double));
+ double *E = malloc(m * p * sizeof(double));
+ if (!A || !B || !C || !E) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_SKIP;
+ }
+ /* Subtract 105 as int: in size_t arithmetic the result wrapped to a huge positive value instead of going negative. */
+ for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251);
+ for (size_t i = 0; i < n * p; i++) B[i] = (double)((int)((i * 11 + 17) % 211) - 105);
+ memset(C, 0, m * p * sizeof(double));
+ memset(E, 0, m * p * sizeof(double));
+ /* NOTE(review): name and type are never read, and the result of matmul() is discarded. */
+ ref_f64_f64_f64(m, n, p, A, B, E, 0.0);
+ matmul(m, n, p, A, B, C, 0.0);
+ /* Tolerance is 1e-9 per element; every exit path below releases all four buffers. */
+ for (size_t i = 0; i < m * p; i++) {
+ double d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > 1e-9) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_FAIL;
+ }
+ }
+
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_OK;
+}
+/* munit entry point: runs test_matmul_u8_i8_u8(); the two string arguments are not read by that helper. */
+static MunitResult test_matmul_u8_i8_u8_const(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_matmul_u8_i8_u8("const", "u8_i8_u8");
+}
+/* munit entry point: same check as test_matmul_u8_i8_u8(), but A and B are declared without const. */
+static MunitResult test_matmul_u8_i8_u8_nonconst(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ uint8_t A[] = {1, 2, 3, 4, 5, 6};
+ int8_t B[] = {1, 0, 0, 1, 0, 0};
+ uint8_t C[4], E[4];
+ /* NOTE(review): presumably verifies that the matmul() macro accepts non-const operands - confirm against the macro. */
+ ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0);
+ matmul(2, 3, 2, A, B, C, 0.0);
+ /* Exact match required between reference output E and macro output C. */
+ for (int i = 0; i < 4; i++) {
+ int d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > 0) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+/* munit entry point: runs test_matmul_u8_i8_u8_scaled(); the two string arguments are not read by that helper. */
+static MunitResult test_matmul_u8_i8_u8_scaled_const(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_matmul_u8_i8_u8_scaled("const", "u8_i8_u8");
+}
+/* munit entry point: runs test_matmul_u8_i8_u8_medium(); the two string arguments are not read by that helper. */
+static MunitResult test_matmul_u8_i8_u8_medium_const(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_matmul_u8_i8_u8_medium("const", "u8_i8_u8");
+}
+/* munit entry point: runs test_matmul_f32_f32_f32(); the two string arguments are not read by that helper. */
+static MunitResult test_matmul_f32_f32_f32_const(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_matmul_f32_f32_f32("const", "f32_f32_f32");
+}
+/* Forwards its arguments to the matmul() macro through non-const float pointers and returns the macro's result. */
+static int matmul_f32_f32_f32_nonconst(size_t m, size_t n, size_t p, float *A, float *B, float *C, double scale) {
+ return matmul(m, n, p, A, B, C, scale);
+}
+/* munit entry point: calls matmul() via matmul_f32_f32_f32_nonconst() and compares with ref_f32_f32_f32() within 1e-5. */
+static MunitResult test_matmul_f32_f32_f32_nonconst(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+ float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
+ float C[4], E[4];
+ /* NOTE(review): the int returned by the wrapper is discarded. */
+ ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0);
+ matmul_f32_f32_f32_nonconst(2, 3, 2, A, B, C, 0.0);
+ /* Any |E[i] - C[i]| above 1e-5 fails the test. */
+ for (int i = 0; i < 4; i++) {
+ double d = (double)E[i] - (double)C[i];
+ if (d < 0) d = -d;
+ if (d > 1e-5) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+/* munit entry point: runs test_matmul_f32_f32_f32_scaled(); the two string arguments are not read by that helper. */
+static MunitResult test_matmul_f32_f32_f32_scaled_const(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_matmul_f32_f32_f32_scaled("const", "f32_f32_f32");
+}
+/* munit entry point: runs test_matmul_f32_f32_f32_medium(); the two string arguments are not read by that helper. */
+static MunitResult test_matmul_f32_f32_f32_medium_const(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_matmul_f32_f32_f32_medium("const", "f32_f32_f32");
+}
+/* munit entry point: runs test_matmul_f64_f64_f64(); the two string arguments are not read by that helper. */
+static MunitResult test_matmul_f64_f64_f64_const(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_matmul_f64_f64_f64("const", "f64_f64_f64");
+}
+/* Forwards its arguments to the matmul() macro through non-const double pointers and returns the macro's result. */
+static int matmul_f64_f64_f64_nonconst(size_t m, size_t n, size_t p, double *A, double *B, double *C, double scale) {
+ return matmul(m, n, p, A, B, C, scale);
+}
+/* munit entry point: calls matmul() via matmul_f64_f64_f64_nonconst() and compares with ref_f64_f64_f64() within 1e-12. */
+static MunitResult test_matmul_f64_f64_f64_nonconst(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+ double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0};
+ double C[4], E[4];
+ /* NOTE(review): the int returned by the wrapper is discarded. */
+ ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0);
+ matmul_f64_f64_f64_nonconst(2, 3, 2, A, B, C, 0.0);
+ /* Any |E[i] - C[i]| above 1e-12 fails the test. */
+ for (int i = 0; i < 4; i++) {
+ double d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > 1e-12) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+/* munit entry point: runs test_matmul_f64_f64_f64_scaled(); the two string arguments are not read by that helper. */
+static MunitResult test_matmul_f64_f64_f64_scaled_const(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_matmul_f64_f64_f64_scaled("const", "f64_f64_f64");
+}
+/* munit entry point: runs test_matmul_f64_f64_f64_medium(); the two string arguments are not read by that helper. */
+static MunitResult test_matmul_f64_f64_f64_medium_const(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_matmul_f64_f64_f64_medium("const", "f64_f64_f64");
+}
+
static MunitTest tests[] = {
{"/scalar-u8-i8-u8", test_scalar_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/scalar-u8-i8-u8-medium", test_scalar_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
@@ -773,6 +1112,22 @@ static MunitTest tests[] = {
{"/dispatched-f64-f64-f64-scaled", test_dispatched_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/dispatched-f64-f64-f64-scaled-medium", test_dispatched_f64_f64_f64_scaled_medium, NULL, NULL,
MUNIT_TEST_OPTION_NONE, NULL},
+ {"/matmul-u8-i8-u8-const", test_matmul_u8_i8_u8_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/matmul-u8-i8-u8-nonconst", test_matmul_u8_i8_u8_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/matmul-u8-i8-u8-scaled-const", test_matmul_u8_i8_u8_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/matmul-u8-i8-u8-medium-const", test_matmul_u8_i8_u8_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/matmul-f32-f32-f32-const", test_matmul_f32_f32_f32_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/matmul-f32-f32-f32-nonconst", test_matmul_f32_f32_f32_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/matmul-f32-f32-f32-scaled-const", test_matmul_f32_f32_f32_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE,
+ NULL},
+ {"/matmul-f32-f32-f32-medium-const", test_matmul_f32_f32_f32_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE,
+ NULL},
+ {"/matmul-f64-f64-f64-const", test_matmul_f64_f64_f64_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/matmul-f64-f64-f64-nonconst", test_matmul_f64_f64_f64_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/matmul-f64-f64-f64-scaled-const", test_matmul_f64_f64_f64_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE,
+ NULL},
+ {"/matmul-f64-f64-f64-medium-const", test_matmul_f64_f64_f64_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE,
+ NULL},
{NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}};
static const MunitSuite suite = {"/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE};