commit 11f476190e1bacace411948da4e194c3f6080a93
parent 8f294137c31dbe5d5e59714450bc9eee49988b63
Author: finwo <finwo@pm.me>
Date: Sat, 18 Apr 2026 04:36:31 +0200
Hard-define the desired public api
Diffstat:
| M | Makefile | | | 18 | ++++++++++++++---- |
| M | src/matmul.h | | | 157 | ++++--------------------------------------------------------------------------- |
2 files changed, 22 insertions(+), 153 deletions(-)
diff --git a/Makefile b/Makefile
@@ -9,17 +9,26 @@ CFLAGS?=-Wall -std=c99
include lib/.dep/config.mk
-CGLAGS+=$(INCLUDES)
+# Flags appended to every compile: the _DEFAULT_SOURCE feature-test macro,
+# x86 instruction-set extensions (FMA, AVX2, AVX-VNNI, AVX-512 F/VNNI/BW),
+# OpenMP support and -O3 optimisation.
+# NOTE(review): the -m* flags allow the compiler to emit these instructions in
+# any translation unit, so the resulting binaries presumably require a CPU
+# that implements all of them — confirm this is intended.
+CFLAGS+=-D_DEFAULT_SOURCE -mfma -mavx2 -mavxvnni -mavx512f -mavx512vnni -mavx512bw -fopenmp -O3
+
+# Extra flags collected in INCLUDES (presumably header search paths set up by
+# lib/.dep/config.mk — TODO confirm).
+CFLAGS+=$(INCLUDES)
OBJ=$(SRC:.c=.o)
BIN=\
- test_matmul
+ test_matmul \
+ benchmark
default: $(BIN)
-$(BIN): $(OBJ) test/$(BIN:=.o)
- $(CC) $(CFLAGS) $(OBJ) test/$@.o -o $@
+# Compile each driver under test/ into an object file next to its source.
+test/%.o: test/%.c
+	$(CC) $(CFLAGS) -c $< -o $@
+
+# Link every program listed in $(BIN) from the library objects plus its own
+# test/<name>.o driver. A single static pattern rule covers all binaries, so
+# adding a program only requires a new entry in $(BIN).
+# $^ expands to the full prerequisite list: $(OBJ) followed by test/<name>.o.
+$(BIN): %: $(OBJ) test/%.o
+	$(CC) $(CFLAGS) $^ -o $@ -lrt -fopenmp
.PHONY: clean
clean:
@@ -32,3 +41,4 @@ clean:
+# format: run clang-format in place (-i) on every *.c and *.h file found under
+# src/ and test/ (assuming $(FIND) resolves to a find(1)-compatible tool).
.PHONY: format
format:
$(FIND) src/ -type f \( -name '*.c' -o -name '*.h' \) -exec clang-format -i {} +
+ $(FIND) test/ -type f \( -name '*.c' -o -name '*.h' \) -exec clang-format -i {} +
diff --git a/src/matmul.h b/src/matmul.h
@@ -45,157 +45,16 @@
extern "C" {
#endif
-extern int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double);
-extern int (*matmul_f32_f32_f64)(size_t, size_t, size_t, const float *, const float *, double *, double);
-extern int (*matmul_f32_f32_i8)(size_t, size_t, size_t, const float *, const float *, int8_t *, double);
-extern int (*matmul_f32_f32_u8)(size_t, size_t, size_t, const float *, const float *, uint8_t *, double);
-extern int (*matmul_f32_f64_f32)(size_t, size_t, size_t, const float *, const double *, float *, double);
-extern int (*matmul_f32_f64_f64)(size_t, size_t, size_t, const float *, const double *, double *, double);
-extern int (*matmul_f32_f64_i8)(size_t, size_t, size_t, const float *, const double *, int8_t *, double);
-extern int (*matmul_f32_f64_u8)(size_t, size_t, size_t, const float *, const double *, uint8_t *, double);
-extern int (*matmul_f32_i8_f32)(size_t, size_t, size_t, const float *, const int8_t *, float *, double);
-extern int (*matmul_f32_i8_f64)(size_t, size_t, size_t, const float *, const int8_t *, double *, double);
-extern int (*matmul_f32_i8_i8)(size_t, size_t, size_t, const float *, const int8_t *, int8_t *, double);
-extern int (*matmul_f32_i8_u8)(size_t, size_t, size_t, const float *, const int8_t *, uint8_t *, double);
-extern int (*matmul_f32_u8_f32)(size_t, size_t, size_t, const float *, const uint8_t *, float *, double);
-extern int (*matmul_f32_u8_f64)(size_t, size_t, size_t, const float *, const uint8_t *, double *, double);
-extern int (*matmul_f32_u8_i8)(size_t, size_t, size_t, const float *, const uint8_t *, int8_t *, double);
-extern int (*matmul_f32_u8_u8)(size_t, size_t, size_t, const float *, const uint8_t *, uint8_t *, double);
-extern int (*matmul_f64_f32_f32)(size_t, size_t, size_t, const double *, const float *, float *, double);
-extern int (*matmul_f64_f32_f64)(size_t, size_t, size_t, const double *, const float *, double *, double);
-extern int (*matmul_f64_f32_i8)(size_t, size_t, size_t, const double *, const float *, int8_t *, double);
-extern int (*matmul_f64_f32_u8)(size_t, size_t, size_t, const double *, const float *, uint8_t *, double);
-extern int (*matmul_f64_f64_f32)(size_t, size_t, size_t, const double *, const double *, float *, double);
-extern int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, double);
-extern int (*matmul_f64_f64_i8)(size_t, size_t, size_t, const double *, const double *, int8_t *, double);
-extern int (*matmul_f64_f64_u8)(size_t, size_t, size_t, const double *, const double *, uint8_t *, double);
-extern int (*matmul_f64_i8_f32)(size_t, size_t, size_t, const double *, const int8_t *, float *, double);
-extern int (*matmul_f64_i8_f64)(size_t, size_t, size_t, const double *, const int8_t *, double *, double);
-extern int (*matmul_f64_i8_i8)(size_t, size_t, size_t, const double *, const int8_t *, int8_t *, double);
-extern int (*matmul_f64_i8_u8)(size_t, size_t, size_t, const double *, const int8_t *, uint8_t *, double);
-extern int (*matmul_f64_u8_f32)(size_t, size_t, size_t, const double *, const uint8_t *, float *, double);
-extern int (*matmul_f64_u8_f64)(size_t, size_t, size_t, const double *, const uint8_t *, double *, double);
-extern int (*matmul_f64_u8_i8)(size_t, size_t, size_t, const double *, const uint8_t *, int8_t *, double);
-extern int (*matmul_f64_u8_u8)(size_t, size_t, size_t, const double *, const uint8_t *, uint8_t *, double);
-extern int (*matmul_i8_f32_f32)(size_t, size_t, size_t, const int8_t *, const float *, float *, double);
-extern int (*matmul_i8_f32_f64)(size_t, size_t, size_t, const int8_t *, const float *, double *, double);
-extern int (*matmul_i8_f32_i8)(size_t, size_t, size_t, const int8_t *, const float *, int8_t *, double);
-extern int (*matmul_i8_f32_u8)(size_t, size_t, size_t, const int8_t *, const float *, uint8_t *, double);
-extern int (*matmul_i8_f64_f32)(size_t, size_t, size_t, const int8_t *, const double *, float *, double);
-extern int (*matmul_i8_f64_f64)(size_t, size_t, size_t, const int8_t *, const double *, double *, double);
-extern int (*matmul_i8_f64_i8)(size_t, size_t, size_t, const int8_t *, const double *, int8_t *, double);
-extern int (*matmul_i8_f64_u8)(size_t, size_t, size_t, const int8_t *, const double *, uint8_t *, double);
-extern int (*matmul_i8_i8_f32)(size_t, size_t, size_t, const int8_t *, const int8_t *, float *, double);
-extern int (*matmul_i8_i8_f64)(size_t, size_t, size_t, const int8_t *, const int8_t *, double *, double);
-extern int (*matmul_i8_i8_i8)(size_t, size_t, size_t, const int8_t *, const int8_t *, int8_t *, double);
-extern int (*matmul_i8_i8_u8)(size_t, size_t, size_t, const int8_t *, const int8_t *, uint8_t *, double);
-extern int (*matmul_i8_u8_f32)(size_t, size_t, size_t, const int8_t *, const uint8_t *, float *, double);
-extern int (*matmul_i8_u8_f64)(size_t, size_t, size_t, const int8_t *, const uint8_t *, double *, double);
-extern int (*matmul_i8_u8_i8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, int8_t *, double);
-extern int (*matmul_i8_u8_u8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, uint8_t *, double);
-extern int (*matmul_u8_f32_f32)(size_t, size_t, size_t, const uint8_t *, const float *, float *, double);
-extern int (*matmul_u8_f32_f64)(size_t, size_t, size_t, const uint8_t *, const float *, double *, double);
-extern int (*matmul_u8_f32_i8)(size_t, size_t, size_t, const uint8_t *, const float *, int8_t *, double);
-extern int (*matmul_u8_f32_u8)(size_t, size_t, size_t, const uint8_t *, const float *, uint8_t *, double);
-extern int (*matmul_u8_f64_f32)(size_t, size_t, size_t, const uint8_t *, const double *, float *, double);
-extern int (*matmul_u8_f64_f64)(size_t, size_t, size_t, const uint8_t *, const double *, double *, double);
-extern int (*matmul_u8_f64_i8)(size_t, size_t, size_t, const uint8_t *, const double *, int8_t *, double);
-extern int (*matmul_u8_f64_u8)(size_t, size_t, size_t, const uint8_t *, const double *, uint8_t *, double);
-extern int (*matmul_u8_i8_f32)(size_t, size_t, size_t, const uint8_t *, const int8_t *, float *, double);
-extern int (*matmul_u8_i8_f64)(size_t, size_t, size_t, const uint8_t *, const int8_t *, double *, double);
-extern int (*matmul_u8_i8_i8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, int8_t *, double);
extern int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double);
-extern int (*matmul_u8_u8_f32)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, float *, double);
-extern int (*matmul_u8_u8_f64)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, double *, double);
-extern int (*matmul_u8_u8_i8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, int8_t *, double);
-extern int (*matmul_u8_u8_u8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, double);
-#define matmul(m, n, p, A, B, C, scale) \
- _Generic((A), \
- float: _Generic((B), \
- float: _Generic((C), \
- float: matmul_f32_f32_f32, \
- double: matmul_f32_f32_f64, \
- int8_t: matmul_f32_f32_i8, \
- uint8_t: matmul_f32_f32_u8), \
- double: _Generic((C), \
- float: matmul_f32_f64_f32, \
- double: matmul_f32_f64_f64, \
- int8_t: matmul_f32_f64_i8, \
- uint8_t: matmul_f32_f64_u8), \
- int8_t: _Generic((C), \
- float: matmul_f32_i8_f32, \
- double: matmul_f32_i8_f64, \
- int8_t: matmul_f32_i8_i8, \
- uint8_t: matmul_f32_i8_u8), \
- uint8_t: _Generic((C), \
- float: matmul_f32_u8_f32, \
- double: matmul_f32_u8_f64, \
- int8_t: matmul_f32_u8_i8, \
- uint8_t: matmul_f32_u8_u8)), \
- double: _Generic((B), \
- float: _Generic((C), \
- float: matmul_f64_f32_f32, \
- double: matmul_f64_f32_f64, \
- int8_t: matmul_f64_f32_i8, \
- uint8_t: matmul_f64_f32_u8), \
- double: _Generic((C), \
- float: matmul_f64_f64_f32, \
- double: matmul_f64_f64_f64, \
- int8_t: matmul_f64_f64_i8, \
- uint8_t: matmul_f64_f64_u8), \
- int8_t: _Generic((C), \
- float: matmul_f64_i8_f32, \
- double: matmul_f64_i8_f64, \
- int8_t: matmul_f64_i8_i8, \
- uint8_t: matmul_f64_i8_u8), \
- uint8_t: _Generic((C), \
- float: matmul_f64_u8_f32, \
- double: matmul_f64_u8_f64, \
- int8_t: matmul_f64_u8_i8, \
- uint8_t: matmul_f64_u8_u8)), \
- int8_t: _Generic((B), \
- float: _Generic((C), \
- float: matmul_i8_f32_f32, \
- double: matmul_i8_f32_f64, \
- int8_t: matmul_i8_f32_i8, \
- uint8_t: matmul_i8_f32_u8), \
- double: _Generic((C), \
- float: matmul_i8_f64_f32, \
- double: matmul_i8_f64_f64, \
- int8_t: matmul_i8_f64_i8, \
- uint8_t: matmul_i8_f64_u8), \
- int8_t: _Generic((C), \
- float: matmul_i8_i8_f32, \
- double: matmul_i8_i8_f64, \
- int8_t: matmul_i8_i8_i8, \
- uint8_t: matmul_i8_i8_u8), \
- uint8_t: _Generic((C), \
- float: matmul_i8_u8_f32, \
- double: matmul_i8_u8_f64, \
- int8_t: matmul_i8_u8_i8, \
- uint8_t: matmul_i8_u8_u8)), \
- uint8_t: _Generic((B), \
- float: _Generic((C), \
- float: matmul_u8_f32_f32, \
- double: matmul_u8_f32_f64, \
- int8_t: matmul_u8_f32_i8, \
- uint8_t: matmul_u8_f32_u8), \
- double: _Generic((C), \
- float: matmul_u8_f64_f32, \
- double: matmul_u8_f64_f64, \
- int8_t: matmul_u8_f64_i8, \
- uint8_t: matmul_u8_f64_u8), \
- int8_t: _Generic((C), \
- float: matmul_u8_i8_f32, \
- double: matmul_u8_i8_f64, \
- int8_t: matmul_u8_i8_i8, \
- uint8_t: matmul_u8_i8_u8), \
- uint8_t: _Generic((C), \
- float: matmul_u8_u8_f32, \
- double: matmul_u8_u8_f64, \
- int8_t: matmul_u8_u8_i8, \
- uint8_t: matmul_u8_u8_u8)))((m), (n), (p), (A), (B), (C), (scale))
+/*
+ * matmul(m, n, p, A, B, C, scale): type-dispatching front end.
+ *
+ * Picks a kernel at compile time from the pointer types of A, B and C, then
+ * calls it with all seven arguments. Only the u8 x i8 -> u8 combination is
+ * available; any other combination matches no association and is rejected
+ * by the compiler.
+ *
+ * A and B are accepted as pointers to either const or non-const elements,
+ * because the kernel declares them as `const uint8_t *` / `const int8_t *`
+ * and callers commonly hold their inputs through const pointers. C is the
+ * kernel's non-const `uint8_t *` parameter, so only that type is accepted.
+ *
+ * The arguments are only inspected for their type inside _Generic; each one
+ * is evaluated exactly once, in the final call.
+ */
+#define MATMUL_DISPATCH_C_(C) _Generic((C), uint8_t *: matmul_u8_i8_u8)
+#define MATMUL_DISPATCH_B_(B, C) \
+  _Generic((B), \
+    int8_t *: MATMUL_DISPATCH_C_(C), \
+    const int8_t *: MATMUL_DISPATCH_C_(C))
+#define matmul(m, n, p, A, B, C, scale) \
+  _Generic((A), \
+    uint8_t *: MATMUL_DISPATCH_B_(B, C), \
+    const uint8_t *: MATMUL_DISPATCH_B_(B, C) \
+  )((m), (n), (p), (A), (B), (C), (scale))
#ifdef __cplusplus
}