matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

commit 11f476190e1bacace411948da4e194c3f6080a93
parent 8f294137c31dbe5d5e59714450bc9eee49988b63
Author: finwo <finwo@pm.me>
Date:   Sat, 18 Apr 2026 04:36:31 +0200

Hard-define the desired public api

Diffstat:
M Makefile | 18 ++++++++++++++----
M src/matmul.h | 157 ++++---------------------------------------------------------------------------
2 files changed, 22 insertions(+), 153 deletions(-)

diff --git a/Makefile b/Makefile @@ -9,17 +9,26 @@ CFLAGS?=-Wall -std=c99 include lib/.dep/config.mk -CGLAGS+=$(INCLUDES) +CFLAGS+=-D_DEFAULT_SOURCE -mfma -mavx2 -mavxvnni -mavx512f -mavx512vnni -mavx512bw -fopenmp -O3 + +CFLAGS+=$(INCLUDES) OBJ=$(SRC:.c=.o) BIN=\ - test_matmul + test_matmul \ + benchmark default: $(BIN) -$(BIN): $(OBJ) test/$(BIN:=.o) - $(CC) $(CFLAGS) $(OBJ) test/$@.o -o $@ +test/%.o: test/%.c + $(CC) $(CFLAGS) -c $< -o $@ + +test_matmul: $(OBJ) test/test_matmul.o + $(CC) $(CFLAGS) $(OBJ) test/test_matmul.o -o $@ -lrt -fopenmp + +benchmark: $(OBJ) test/benchmark.o + $(CC) $(CFLAGS) $(OBJ) test/benchmark.o -o $@ -lrt -fopenmp .PHONY: clean clean: @@ -32,3 +41,4 @@ clean: .PHONY: format format: $(FIND) src/ -type f \( -name '*.c' -o -name '*.h' \) -exec clang-format -i {} + + $(FIND) test/ -type f \( -name '*.c' -o -name '*.h' \) -exec clang-format -i {} + diff --git a/src/matmul.h b/src/matmul.h @@ -45,157 +45,16 @@ extern "C" { #endif -extern int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double); -extern int (*matmul_f32_f32_f64)(size_t, size_t, size_t, const float *, const float *, double *, double); -extern int (*matmul_f32_f32_i8)(size_t, size_t, size_t, const float *, const float *, int8_t *, double); -extern int (*matmul_f32_f32_u8)(size_t, size_t, size_t, const float *, const float *, uint8_t *, double); -extern int (*matmul_f32_f64_f32)(size_t, size_t, size_t, const float *, const double *, float *, double); -extern int (*matmul_f32_f64_f64)(size_t, size_t, size_t, const float *, const double *, double *, double); -extern int (*matmul_f32_f64_i8)(size_t, size_t, size_t, const float *, const double *, int8_t *, double); -extern int (*matmul_f32_f64_u8)(size_t, size_t, size_t, const float *, const double *, uint8_t *, double); -extern int (*matmul_f32_i8_f32)(size_t, size_t, size_t, const float *, const int8_t *, float *, double); -extern int (*matmul_f32_i8_f64)(size_t, size_t, size_t, const float 
*, const int8_t *, double *, double); -extern int (*matmul_f32_i8_i8)(size_t, size_t, size_t, const float *, const int8_t *, int8_t *, double); -extern int (*matmul_f32_i8_u8)(size_t, size_t, size_t, const float *, const int8_t *, uint8_t *, double); -extern int (*matmul_f32_u8_f32)(size_t, size_t, size_t, const float *, const uint8_t *, float *, double); -extern int (*matmul_f32_u8_f64)(size_t, size_t, size_t, const float *, const uint8_t *, double *, double); -extern int (*matmul_f32_u8_i8)(size_t, size_t, size_t, const float *, const uint8_t *, int8_t *, double); -extern int (*matmul_f32_u8_u8)(size_t, size_t, size_t, const float *, const uint8_t *, uint8_t *, double); -extern int (*matmul_f64_f32_f32)(size_t, size_t, size_t, const double *, const float *, float *, double); -extern int (*matmul_f64_f32_f64)(size_t, size_t, size_t, const double *, const float *, double *, double); -extern int (*matmul_f64_f32_i8)(size_t, size_t, size_t, const double *, const float *, int8_t *, double); -extern int (*matmul_f64_f32_u8)(size_t, size_t, size_t, const double *, const float *, uint8_t *, double); -extern int (*matmul_f64_f64_f32)(size_t, size_t, size_t, const double *, const double *, float *, double); -extern int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, double); -extern int (*matmul_f64_f64_i8)(size_t, size_t, size_t, const double *, const double *, int8_t *, double); -extern int (*matmul_f64_f64_u8)(size_t, size_t, size_t, const double *, const double *, uint8_t *, double); -extern int (*matmul_f64_i8_f32)(size_t, size_t, size_t, const double *, const int8_t *, float *, double); -extern int (*matmul_f64_i8_f64)(size_t, size_t, size_t, const double *, const int8_t *, double *, double); -extern int (*matmul_f64_i8_i8)(size_t, size_t, size_t, const double *, const int8_t *, int8_t *, double); -extern int (*matmul_f64_i8_u8)(size_t, size_t, size_t, const double *, const int8_t *, uint8_t *, double); -extern int 
(*matmul_f64_u8_f32)(size_t, size_t, size_t, const double *, const uint8_t *, float *, double); -extern int (*matmul_f64_u8_f64)(size_t, size_t, size_t, const double *, const uint8_t *, double *, double); -extern int (*matmul_f64_u8_i8)(size_t, size_t, size_t, const double *, const uint8_t *, int8_t *, double); -extern int (*matmul_f64_u8_u8)(size_t, size_t, size_t, const double *, const uint8_t *, uint8_t *, double); -extern int (*matmul_i8_f32_f32)(size_t, size_t, size_t, const int8_t *, const float *, float *, double); -extern int (*matmul_i8_f32_f64)(size_t, size_t, size_t, const int8_t *, const float *, double *, double); -extern int (*matmul_i8_f32_i8)(size_t, size_t, size_t, const int8_t *, const float *, int8_t *, double); -extern int (*matmul_i8_f32_u8)(size_t, size_t, size_t, const int8_t *, const float *, uint8_t *, double); -extern int (*matmul_i8_f64_f32)(size_t, size_t, size_t, const int8_t *, const double *, float *, double); -extern int (*matmul_i8_f64_f64)(size_t, size_t, size_t, const int8_t *, const double *, double *, double); -extern int (*matmul_i8_f64_i8)(size_t, size_t, size_t, const int8_t *, const double *, int8_t *, double); -extern int (*matmul_i8_f64_u8)(size_t, size_t, size_t, const int8_t *, const double *, uint8_t *, double); -extern int (*matmul_i8_i8_f32)(size_t, size_t, size_t, const int8_t *, const int8_t *, float *, double); -extern int (*matmul_i8_i8_f64)(size_t, size_t, size_t, const int8_t *, const int8_t *, double *, double); -extern int (*matmul_i8_i8_i8)(size_t, size_t, size_t, const int8_t *, const int8_t *, int8_t *, double); -extern int (*matmul_i8_i8_u8)(size_t, size_t, size_t, const int8_t *, const int8_t *, uint8_t *, double); -extern int (*matmul_i8_u8_f32)(size_t, size_t, size_t, const int8_t *, const uint8_t *, float *, double); -extern int (*matmul_i8_u8_f64)(size_t, size_t, size_t, const int8_t *, const uint8_t *, double *, double); -extern int (*matmul_i8_u8_i8)(size_t, size_t, size_t, const int8_t *, const 
uint8_t *, int8_t *, double); -extern int (*matmul_i8_u8_u8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, uint8_t *, double); -extern int (*matmul_u8_f32_f32)(size_t, size_t, size_t, const uint8_t *, const float *, float *, double); -extern int (*matmul_u8_f32_f64)(size_t, size_t, size_t, const uint8_t *, const float *, double *, double); -extern int (*matmul_u8_f32_i8)(size_t, size_t, size_t, const uint8_t *, const float *, int8_t *, double); -extern int (*matmul_u8_f32_u8)(size_t, size_t, size_t, const uint8_t *, const float *, uint8_t *, double); -extern int (*matmul_u8_f64_f32)(size_t, size_t, size_t, const uint8_t *, const double *, float *, double); -extern int (*matmul_u8_f64_f64)(size_t, size_t, size_t, const uint8_t *, const double *, double *, double); -extern int (*matmul_u8_f64_i8)(size_t, size_t, size_t, const uint8_t *, const double *, int8_t *, double); -extern int (*matmul_u8_f64_u8)(size_t, size_t, size_t, const uint8_t *, const double *, uint8_t *, double); -extern int (*matmul_u8_i8_f32)(size_t, size_t, size_t, const uint8_t *, const int8_t *, float *, double); -extern int (*matmul_u8_i8_f64)(size_t, size_t, size_t, const uint8_t *, const int8_t *, double *, double); -extern int (*matmul_u8_i8_i8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, int8_t *, double); extern int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double); -extern int (*matmul_u8_u8_f32)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, float *, double); -extern int (*matmul_u8_u8_f64)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, double *, double); -extern int (*matmul_u8_u8_i8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, int8_t *, double); -extern int (*matmul_u8_u8_u8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, double); -#define matmul(m, n, p, A, B, C, scale) \ - _Generic((A), \ - float: _Generic((B), \ - float: _Generic((C), \ - float: 
matmul_f32_f32_f32, \ - double: matmul_f32_f32_f64, \ - int8_t: matmul_f32_f32_i8, \ - uint8_t: matmul_f32_f32_u8), \ - double: _Generic((C), \ - float: matmul_f32_f64_f32, \ - double: matmul_f32_f64_f64, \ - int8_t: matmul_f32_f64_i8, \ - uint8_t: matmul_f32_f64_u8), \ - int8_t: _Generic((C), \ - float: matmul_f32_i8_f32, \ - double: matmul_f32_i8_f64, \ - int8_t: matmul_f32_i8_i8, \ - uint8_t: matmul_f32_i8_u8), \ - uint8_t: _Generic((C), \ - float: matmul_f32_u8_f32, \ - double: matmul_f32_u8_f64, \ - int8_t: matmul_f32_u8_i8, \ - uint8_t: matmul_f32_u8_u8)), \ - double: _Generic((B), \ - float: _Generic((C), \ - float: matmul_f64_f32_f32, \ - double: matmul_f64_f32_f64, \ - int8_t: matmul_f64_f32_i8, \ - uint8_t: matmul_f64_f32_u8), \ - double: _Generic((C), \ - float: matmul_f64_f64_f32, \ - double: matmul_f64_f64_f64, \ - int8_t: matmul_f64_f64_i8, \ - uint8_t: matmul_f64_f64_u8), \ - int8_t: _Generic((C), \ - float: matmul_f64_i8_f32, \ - double: matmul_f64_i8_f64, \ - int8_t: matmul_f64_i8_i8, \ - uint8_t: matmul_f64_i8_u8), \ - uint8_t: _Generic((C), \ - float: matmul_f64_u8_f32, \ - double: matmul_f64_u8_f64, \ - int8_t: matmul_f64_u8_i8, \ - uint8_t: matmul_f64_u8_u8)), \ - int8_t: _Generic((B), \ - float: _Generic((C), \ - float: matmul_i8_f32_f32, \ - double: matmul_i8_f32_f64, \ - int8_t: matmul_i8_f32_i8, \ - uint8_t: matmul_i8_f32_u8), \ - double: _Generic((C), \ - float: matmul_i8_f64_f32, \ - double: matmul_i8_f64_f64, \ - int8_t: matmul_i8_f64_i8, \ - uint8_t: matmul_i8_f64_u8), \ - int8_t: _Generic((C), \ - float: matmul_i8_i8_f32, \ - double: matmul_i8_i8_f64, \ - int8_t: matmul_i8_i8_i8, \ - uint8_t: matmul_i8_i8_u8), \ - uint8_t: _Generic((C), \ - float: matmul_i8_u8_f32, \ - double: matmul_i8_u8_f64, \ - int8_t: matmul_i8_u8_i8, \ - uint8_t: matmul_i8_u8_u8)), \ - uint8_t: _Generic((B), \ - float: _Generic((C), \ - float: matmul_u8_f32_f32, \ - double: matmul_u8_f32_f64, \ - int8_t: matmul_u8_f32_i8, \ - uint8_t: matmul_u8_f32_u8), \ - 
double: _Generic((C), \ - float: matmul_u8_f64_f32, \ - double: matmul_u8_f64_f64, \ - int8_t: matmul_u8_f64_i8, \ - uint8_t: matmul_u8_f64_u8), \ - int8_t: _Generic((C), \ - float: matmul_u8_i8_f32, \ - double: matmul_u8_i8_f64, \ - int8_t: matmul_u8_i8_i8, \ - uint8_t: matmul_u8_i8_u8), \ - uint8_t: _Generic((C), \ - float: matmul_u8_u8_f32, \ - double: matmul_u8_u8_f64, \ - int8_t: matmul_u8_u8_i8, \ - uint8_t: matmul_u8_u8_u8)))((m), (n), (p), (A), (B), (C), (scale)) +#define matmul(m, n, p, A, B, C, scale) \ + _Generic((A), \ + uint8_t *: _Generic((B), \ + int8_t *: _Generic((C), \ + uint8_t *: matmul_u8_i8_u8 \ + ) \ + ) \ + )((m), (n), (p), (A), (B), (C), (scale)) #ifdef __cplusplus }