matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

commit 8cc0cb57ecf252548aa53e68c41deb151b119360
parent 25ddb45ce8dd8da044056d27bbc564863d01c943
Author: finwo <finwo@pm.me>
Date:   Sat, 18 Apr 2026 22:03:41 +0200

Re-implement scaling

Diffstat:
Msrc/matmul.c | 3+++
Mtest/test_matmul.c | 116++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---
2 files changed, 116 insertions(+), 3 deletions(-)

diff --git a/src/matmul.c b/src/matmul.c @@ -142,6 +142,7 @@ int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const for (size_t j = jj; j < j_end; j++) { size_t lj = j - jj; int v = acc[li * tj + lj]; + if (scale > 1.0) v = (int)(v / scale); if (v > 255) v = 255; else if (v < 0) @@ -197,6 +198,7 @@ int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, c _mm512_store_si512(tmp, result); for (size_t dj = 0; dj < 16; dj++) { int32_t v = tmp[dj]; + if (scale > 1.0) v = (int32_t)(v / scale); if (v > 255) v = 255; else if (v < 0) @@ -209,6 +211,7 @@ int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, c for (size_t k = 0; k < n; k++) { sum += (int)A[i * n + k] * (int)B[k * p + j]; } + if (scale > 1.0) sum = (int32_t)(sum / scale); if (sum > 255) sum = 255; else if (sum < 0) diff --git a/test/test_matmul.c b/test/test_matmul.c @@ -8,11 +8,12 @@ #include "nemequ/munit.h" #include "test_matmul_simd.h" -static void ref_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C) { +static void ref_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int sum = 0; for (size_t k = 0; k < n; k++) sum += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale > 1.0) sum = (int)(sum / scale); if (sum > 255) sum = 255; if (sum < 0) sum = 0; C[i * p + j] = (uint8_t)sum; @@ -27,7 +28,7 @@ static MunitResult test_u8_i8_u8_small(const char *name, int8_t B[] = {1, 0, 0, 1, 0, 0}; uint8_t C[4], E[4]; - ref_u8_i8_u8(2, 3, 2, A, B, E); + ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0); matmul_fn(2, 3, 2, A, B, C, 0.0); for (int i = 0; i < 4; i++) { @@ -61,7 +62,7 @@ static MunitResult test_u8_i8_u8_medium(const char *name, memset(C, 0, m * p); memset(E, 0, m * p); - ref_u8_i8_u8(m, n, p, A, B, E); + ref_u8_i8_u8(m, n, p, A, B, E, 0.0); matmul_fn(m, n, p, A, B, C, 0.0); for (size_t i = 0; i < m * p; i++) { @@ -128,15 +129,124 @@ static MunitResult test_dispatched_u8_i8_u8_medium(const MunitParameter *params, return test_u8_i8_u8_medium("dispatched", matmul_u8_i8_u8, 0); } +static MunitResult test_u8_i8_u8_scaled_small(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, + uint8_t *, double), + double epsilon) { + uint8_t A[] = {8, 16, 24, 32, 40, 48}; + int8_t B[] = {2, 0, 0, 2, 0, 0}; + uint8_t C[4], E[4]; + + ref_u8_i8_u8(2, 3, 2, A, B, E, 4.0); + matmul_fn(2, 3, 2, A, B, C, 4.0); + + for (int i = 0; i < 4; i++) { + int d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > (int)epsilon) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_u8_i8_u8_scaled_medium(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, + uint8_t *, double), + double epsilon) { + const size_t m = 64, n = 64, p = 64; + uint8_t *A = malloc(m * n); + int8_t *B = malloc(n * p); + uint8_t *C = malloc(m * p); + uint8_t *E = malloc(m * p); + if (!A || !B || !C || !E) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_SKIP; + } + + for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)((i * 7 + 13) % 251); + for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(((i * 11 + 17) % 211) - 105); + memset(C, 0, m * p); + memset(E, 0, m * p); + + ref_u8_i8_u8(m, n, p, A, B, E, 8.0); + matmul_fn(m, n, p, A, B, C, 8.0); + + for (size_t i = 0; i < m * p; i++) { + int d = (int)E[i] - (int)C[i]; + if (d < 0) d = -d; + if (d > (int)epsilon) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_FAIL; + } + } + + free(A); + free(B); + free(C); + free(E); + return MUNIT_OK; +} + +static MunitResult test_scalar_u8_i8_u8_scaled(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8_scaled_small("scalar", matmul_scalar_u8_i8_u8, 0); +} + +static MunitResult test_scalar_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8_scaled_medium("scalar", matmul_scalar_u8_i8_u8, 0); +} + +#ifdef __AVX512VNNI__ +static MunitResult test_avx512vnni_u8_i8_u8_scaled(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8_scaled_small("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0); +} + +static MunitResult test_avx512vnni_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8_scaled_medium("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0); +} +#endif + +static MunitResult test_dispatched_u8_i8_u8_scaled(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8_scaled_small("dispatched", matmul_u8_i8_u8, 0); +} + +static MunitResult test_dispatched_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8_scaled_medium("dispatched", matmul_u8_i8_u8, 0); +} + static MunitTest tests[] = { {"/scalar-u8-i8-u8", test_scalar_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/scalar-u8-i8-u8-medium", test_scalar_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/scalar-u8-i8-u8-scaled", test_scalar_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/scalar-u8-i8-u8-scaled-medium", test_scalar_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, #ifdef __AVX512VNNI__ {"/avx512vnni-u8-i8-u8", test_avx512vnni_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/avx512vnni-u8-i8-u8-medium", test_avx512vnni_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx512vnni-u8-i8-u8-scaled", test_avx512vnni_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx512vnni-u8-i8-u8-scaled-medium", test_avx512vnni_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, + NULL}, #endif {"/dispatched-u8-i8-u8", test_dispatched_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/dispatched-u8-i8-u8-medium", test_dispatched_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/dispatched-u8-i8-u8-scaled", test_dispatched_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/dispatched-u8-i8-u8-scaled-medium", test_dispatched_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, + NULL}, {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}}; static const MunitSuite suite = {"/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE};