matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

commit a82259b0db04022dd7c215fa0b33aab9315cda24
parent b928a1f946e692c012b0eb936520c2bd960b058c
Author: finwo <finwo@pm.me>
Date:   Thu, 23 Apr 2026 14:31:23 +0200

Simplified main macro

Diffstat:
M README.md    | 63 +++++++++++++++++++++++++++++++++++++++++----------------------
M src/matmul.h | 51 +++++++++++----------------------------------------
2 files changed, 52 insertions(+), 62 deletions(-)

diff --git a/README.md b/README.md @@ -2,8 +2,10 @@ A lightweight, type-safe matrix multiplication library with runtime dispatch and SIMD acceleration. -**Current Implementation:** `uint8_t` × `int8_t` → `uint8_t` -*(additional type combinations planned)* +**Current Implementations:** +- `uint8_t` × `int8_t` → `uint8_t` +- `float` × `float` → `float` +- `double` × `double` → `double` ## Installation @@ -25,6 +27,7 @@ Alternatively, you can include the [matmul.c](src/matmul.c) and [matmul.h](src/m ## Quick Start +### Quantized Multiplication (u8 × i8 → u8) ```c #include "finwo/matmul.h" #include <stdio.h> @@ -40,7 +43,6 @@ int main() { // Multiply A(2x3) * B(3x2) = C(2x2) matmul(2, 3, 2, A, B, C, 0.0); // scale=0: no scaling - // C should be [1, 2, 4, 5] printf("C[0] = %u, C[1] = %u, C[2] = %u, C[3] = %u\n", C[0], C[1], C[2], C[3]); @@ -48,6 +50,23 @@ int main() { } ``` +### Floating Point Multiplication (f32 × f32 → f32) +```c +#include "finwo/matmul.h" +#include <stdio.h> + +int main() { + float A[4] = {1.0f, 2.0f, 3.0f, 4.0f}; // 2x2 + float B[4] = {5.0f, 6.0f, 7.0f, 8.0f}; // 2x2 + float C[4]; + + matmul(2, 2, 2, A, B, C, 1.0); + + printf("C[0] = %f\n", C[0]); + return 0; +} +``` + Compile with: `cc -o example example.c -lm -fopenmp` ## API Reference @@ -61,16 +80,25 @@ matmul(m, n, p, A, B, C, scale); ### Direct Function Calls ```c -// Currently implemented type combination (u8 × i8 → u8) +// Implemented type combinations int matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale); -// Scalar and SIMD variants +int matmul_f32_f32_f32(size_t m, size_t n, size_t p, + const float *A, const float *B, + float *C, double scale); + +int matmul_f64_f64_f64(size_t m, size_t n, size_t p, + const double *A, const double *B, + double *C, double scale); + +// Scalar and SIMD variants (internal/specialized) int matmul_scalar_u8_i8_u8(...); -int matmul_avx2_u8_i8_u8(...); -int matmul_avx512_u8_i8_u8(...); -int 
matmul_avxvnni_u8_i8_u8(...); +int matmul_avx512vnni_u8_i8_u8(...); +int matmul_scalar_f32_f32_f32(...); +int matmul_avx2_f32_f32_f32(...); +int matmul_avx512_f32_f32_f32(...); ``` ### Type Naming Conventions @@ -114,18 +142,6 @@ make clean Requires a C11 compiler (gcc, clang, MSVC) with OpenMP support. -## Performance - -On an AMD Ryzen 7 9800X3D (Zen 5, 16 cores), with AVX2/AVX-VNNI/AVX512/AVX512-VNNI enabled: - -| Matrix Size | Time (ms) | GFLOPS | -|-------------|-----------|---------| -| 1024×1024 | ~2.4 | ~880 | -| 4096×4096 | ~173 | ~795 | -| 16384×16384 | ~12970 | ~678 | - -The 16384×16384 case exceeds L3 cache (~768MB total), so performance is memory-bound. For optimal performance use matrices that fit in L3 (≤4096 on most CPUs). - ## Testing The library includes unit tests verifying correctness across all implementations: @@ -135,8 +151,11 @@ The library includes unit tests verifying correctness across all implementations ## Implementation Notes -- **Automatic dispatch**: The first call runtime-detects CPU features and selects the optimal implementation -- **Dispatch priority**: AVX512-VNNI → AVX512 → AVX-VNNI → AVX2 → Scalar +- **Automatic dispatch**: The first call runtime-detects CPU features and selects the optimal implementation for the given types +- **Dispatch priority**: + - `u8_i8_u8`: AVX512-VNNI → Scalar + - `f32_f32_f32`: AVX512 → AVX2 → Scalar + - `f64_f64_f64`: AVX512 → AVX2 → Scalar - **Parallelization**: OpenMP `parallel for` with `static` scheduling across row blocks - **Tiling**: Blocking factors tuned for L1/L2 cache (ib=32/64, jb=64, kb=32/64 depending on SIMD width) diff --git a/src/matmul.h b/src/matmul.h @@ -50,48 +50,19 @@ extern "C" { #define __matmul_TYPE_f32 float #define __matmul_TYPE_f64 double -#define __matmul_EXTERN(__MATMUL_EXT_A, __MATMUL_EXT_B, __MATMUL_EXT_C) extern int (*matmul_##__MATMUL_EXT_A##_##__MATMUL_EXT_B##_##__MATMUL_EXT_C)(size_t, size_t, size_t, const __matmul_TYPE_##__MATMUL_EXT_A *, const 
__matmul_TYPE_##__MATMUL_EXT_B *, __matmul_TYPE_##__MATMUL_EXT_C *, double); +extern int matmul_not_implemented(size_t m, size_t n, size_t p, void *A, void *B, void *C, double scale); -#define __matmul_EXTERN_C(__MATMUL_EXT_A, __MATMUL_EXT_B) \ - __matmul_EXTERN(__MATMUL_EXT_A, __MATMUL_EXT_B, u8 ) \ - __matmul_EXTERN(__MATMUL_EXT_A, __MATMUL_EXT_B, f32) \ - __matmul_EXTERN(__MATMUL_EXT_A, __MATMUL_EXT_B, f64) +extern int (*matmul_u8_i8_u8 )(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale); +extern int (*matmul_f32_f32_f32)(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale); +extern int (*matmul_f64_f64_f64)(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale); -#define __matmul_EXTERN_B(__MATMUL_EXT_A) \ - __matmul_EXTERN_C(__MATMUL_EXT_A, i8 ) \ - __matmul_EXTERN_C(__MATMUL_EXT_A, f32 ) \ - __matmul_EXTERN_C(__MATMUL_EXT_A, f64 ) - -__matmul_EXTERN_B(u8 ) -__matmul_EXTERN_B(f32) -__matmul_EXTERN_B(f64) - -#define __matmul_C(__matmul_arg_type_a,__matmul_arg_type_b) \ - _Generic((__MATMUL_ARG_C), \ - uint8_t *: matmul_##__matmul_arg_type_a##_##__matmul_arg_type_b##_u8, \ - float *: matmul_##__matmul_arg_type_a##_##__matmul_arg_type_b##_f32, \ - double *: matmul_##__matmul_arg_type_a##_##__matmul_arg_type_b##_f64 \ - ) - -#define __matmul_B(__matmul_arg_type_a) \ - _Generic((__MATMUL_ARG_B), \ - int8_t *: __matmul_C(__matmul_arg_type_a,i8) \ - const int8_t *: __matmul_C(__matmul_arg_type_a,i8) \ - float *: __matmul_C(__matmul_arg_type_a,f32) \ - const float *: __matmul_C(__matmul_arg_type_a,f32) \ - double *: __matmul_C(__matmul_arg_type_a,f64) \ - const double *: __matmul_C(__matmul_arg_type_a,f64) \ - ) - -#define matmul(m,n,p,__MATMUL_ARG_A,__MATMUL_ARG_B,__MATMUL_ARG_C,scale) \ - _Generic((__MATMUL_ARG_A), \ - uint8_t *: __matmul_B(u8) \ - const uint8_t *: __matmul_B(u8) \ - float *: __matmul_B(f32) \ - const float *: __matmul_B(f32) \ - double *: 
__matmul_B(f64) \ - const double *: __matmul_B(f64) \ - )(m,n,p,__MATMUL_ARG_A,__MATMUL_ARG_B,__MATMUL_ARG_C,scale) +#define matmul(m,n,p,A,B,C) \ + _Generic((void (*)(__typeof(A),__typeof(B),__typeof(C)))NULL, \ + default: matmul_not_implemented, \ + void (*)(uint8_t *, int8_t *, uint8_t *): matmul_u8_i8_u8 \ + void (*)(float *, float *, float *): matmul_f32_f32_f32 \ + void (*)(double *, double *, double *): matmul_f64_f64_f64 \ + )(m,n,p,A,B,C,scale) #ifdef __cplusplus }