commit 8cc0cb57ecf252548aa53e68c41deb151b119360
parent 25ddb45ce8dd8da044056d27bbc564863d01c943
Author: finwo <finwo@pm.me>
Date: Sat, 18 Apr 2026 22:03:41 +0200
Re-implement scaling
Diffstat:
2 files changed, 116 insertions(+), 3 deletions(-)
diff --git a/src/matmul.c b/src/matmul.c
@@ -142,6 +142,7 @@ int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const
for (size_t j = jj; j < j_end; j++) {
size_t lj = j - jj;
int v = acc[li * tj + lj];
+ if (scale > 1.0) v = (int)(v / scale);
if (v > 255)
v = 255;
else if (v < 0)
@@ -197,6 +198,7 @@ int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, c
_mm512_store_si512(tmp, result);
for (size_t dj = 0; dj < 16; dj++) {
int32_t v = tmp[dj];
+ if (scale > 1.0) v = (int32_t)(v / scale);
if (v > 255)
v = 255;
else if (v < 0)
@@ -209,6 +211,7 @@ int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, c
for (size_t k = 0; k < n; k++) {
sum += (int)A[i * n + k] * (int)B[k * p + j];
}
+ if (scale > 1.0) sum = (int32_t)(sum / scale);
if (sum > 255)
sum = 255;
else if (sum < 0)
diff --git a/test/test_matmul.c b/test/test_matmul.c
@@ -8,11 +8,12 @@
#include "nemequ/munit.h"
#include "test_matmul_simd.h"
-static void ref_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C) {
+static void ref_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
for (size_t i = 0; i < m; i++)
for (size_t j = 0; j < p; j++) {
int sum = 0;
for (size_t k = 0; k < n; k++) sum += (int)A[i * n + k] * (int)B[k * p + j];
+ if (scale > 1.0) sum = (int)(sum / scale);
if (sum > 255) sum = 255;
if (sum < 0) sum = 0;
C[i * p + j] = (uint8_t)sum;
@@ -27,7 +28,7 @@ static MunitResult test_u8_i8_u8_small(const char *name,
int8_t B[] = {1, 0, 0, 1, 0, 0};
uint8_t C[4], E[4];
- ref_u8_i8_u8(2, 3, 2, A, B, E);
+ ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0);
matmul_fn(2, 3, 2, A, B, C, 0.0);
for (int i = 0; i < 4; i++) {
@@ -61,7 +62,7 @@ static MunitResult test_u8_i8_u8_medium(const char *name,
memset(C, 0, m * p);
memset(E, 0, m * p);
- ref_u8_i8_u8(m, n, p, A, B, E);
+ ref_u8_i8_u8(m, n, p, A, B, E, 0.0);
matmul_fn(m, n, p, A, B, C, 0.0);
for (size_t i = 0; i < m * p; i++) {
@@ -128,15 +129,124 @@ static MunitResult test_dispatched_u8_i8_u8_medium(const MunitParameter *params,
return test_u8_i8_u8_medium("dispatched", matmul_u8_i8_u8, 0);
}
+static MunitResult test_u8_i8_u8_scaled_small(const char *name,
+ int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *,
+ uint8_t *, double),
+ double epsilon) {
+ uint8_t A[] = {8, 16, 24, 32, 40, 48};
+ int8_t B[] = {2, 0, 0, 2, 0, 0};
+ uint8_t C[4], E[4];
+
+ ref_u8_i8_u8(2, 3, 2, A, B, E, 4.0);
+ matmul_fn(2, 3, 2, A, B, C, 4.0);
+
+ for (int i = 0; i < 4; i++) {
+ int d = E[i] - C[i];
+ if (d < 0) d = -d;
+ if (d > (int)epsilon) return MUNIT_FAIL;
+ }
+ return MUNIT_OK;
+}
+
+static MunitResult test_u8_i8_u8_scaled_medium(const char *name,
+ int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *,
+ uint8_t *, double),
+ double epsilon) {
+ const size_t m = 64, n = 64, p = 64;
+ uint8_t *A = malloc(m * n);
+ int8_t *B = malloc(n * p);
+ uint8_t *C = malloc(m * p);
+ uint8_t *E = malloc(m * p);
+ if (!A || !B || !C || !E) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_SKIP;
+ }
+
+ for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)((i * 7 + 13) % 251);
+ for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(((i * 11 + 17) % 211) - 105);
+ memset(C, 0, m * p);
+ memset(E, 0, m * p);
+
+ ref_u8_i8_u8(m, n, p, A, B, E, 8.0);
+ matmul_fn(m, n, p, A, B, C, 8.0);
+
+ for (size_t i = 0; i < m * p; i++) {
+ int d = (int)E[i] - (int)C[i];
+ if (d < 0) d = -d;
+ if (d > (int)epsilon) {
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_FAIL;
+ }
+ }
+
+ free(A);
+ free(B);
+ free(C);
+ free(E);
+ return MUNIT_OK;
+}
+
+static MunitResult test_scalar_u8_i8_u8_scaled(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8_scaled_small("scalar", matmul_scalar_u8_i8_u8, 0);
+}
+
+static MunitResult test_scalar_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8_scaled_medium("scalar", matmul_scalar_u8_i8_u8, 0);
+}
+
+#ifdef __AVX512VNNI__
+static MunitResult test_avx512vnni_u8_i8_u8_scaled(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8_scaled_small("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0);
+}
+
+static MunitResult test_avx512vnni_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8_scaled_medium("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0);
+}
+#endif
+
+static MunitResult test_dispatched_u8_i8_u8_scaled(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8_scaled_small("dispatched", matmul_u8_i8_u8, 0);
+}
+
+static MunitResult test_dispatched_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8_scaled_medium("dispatched", matmul_u8_i8_u8, 0);
+}
+
static MunitTest tests[] = {
{"/scalar-u8-i8-u8", test_scalar_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/scalar-u8-i8-u8-medium", test_scalar_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/scalar-u8-i8-u8-scaled", test_scalar_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/scalar-u8-i8-u8-scaled-medium", test_scalar_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
#ifdef __AVX512VNNI__
{"/avx512vnni-u8-i8-u8", test_avx512vnni_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/avx512vnni-u8-i8-u8-medium", test_avx512vnni_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx512vnni-u8-i8-u8-scaled", test_avx512vnni_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx512vnni-u8-i8-u8-scaled-medium", test_avx512vnni_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
+ NULL},
#endif
{"/dispatched-u8-i8-u8", test_dispatched_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
{"/dispatched-u8-i8-u8-medium", test_dispatched_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/dispatched-u8-i8-u8-scaled", test_dispatched_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/dispatched-u8-i8-u8-scaled-medium", test_dispatched_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
+ NULL},
{NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}};
static const MunitSuite suite = {"/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE};