commit a82259b0db04022dd7c215fa0b33aab9315cda24
parent b928a1f946e692c012b0eb936520c2bd960b058c
Author: finwo <finwo@pm.me>
Date: Thu, 23 Apr 2026 14:31:23 +0200
Simplified main macro
Diffstat:
| M | README.md | | | 63 | +++++++++++++++++++++++++++++++++++++++++---------------------- |
| M | src/matmul.h | | | 51 | +++++++++++---------------------------------------- |
2 files changed, 52 insertions(+), 62 deletions(-)
diff --git a/README.md b/README.md
@@ -2,8 +2,10 @@
A lightweight, type-safe matrix multiplication library with runtime dispatch and SIMD acceleration.
-**Current Implementation:** `uint8_t` × `int8_t` → `uint8_t`
-*(additional type combinations planned)*
+**Current Implementations:**
+- `uint8_t` × `int8_t` → `uint8_t`
+- `float` × `float` → `float`
+- `double` × `double` → `double`
## Installation
@@ -25,6 +27,7 @@ Alternatively, you can include the [matmul.c](src/matmul.c) and [matmul.h](src/m
## Quick Start
+### Quantized Multiplication (u8 × i8 → u8)
```c
#include "finwo/matmul.h"
#include <stdio.h>
@@ -40,7 +43,6 @@ int main() {
// Multiply A(2x3) * B(3x2) = C(2x2)
matmul(2, 3, 2, A, B, C, 0.0); // scale=0: no scaling
- // C should be [1, 2, 4, 5]
printf("C[0] = %u, C[1] = %u, C[2] = %u, C[3] = %u\n",
C[0], C[1], C[2], C[3]);
@@ -48,6 +50,23 @@ int main() {
}
```
+### Floating Point Multiplication (f32 × f32 → f32)
+```c
+#include "finwo/matmul.h"
+#include <stdio.h>
+
+int main() {
+ float A[4] = {1.0f, 2.0f, 3.0f, 4.0f}; // 2x2
+ float B[4] = {5.0f, 6.0f, 7.0f, 8.0f}; // 2x2
+ float C[4];
+
+ matmul(2, 2, 2, A, B, C, 1.0);
+
+ printf("C[0] = %f\n", C[0]);
+ return 0;
+}
+```
+
Compile with: `cc -o example example.c -lm -fopenmp`
## API Reference
@@ -61,16 +80,25 @@ matmul(m, n, p, A, B, C, scale);
### Direct Function Calls
```c
-// Currently implemented type combination (u8 × i8 → u8)
+// Implemented type combinations
int matmul_u8_i8_u8(size_t m, size_t n, size_t p,
const uint8_t *A, const int8_t *B,
uint8_t *C, double scale);
-// Scalar and SIMD variants
+int matmul_f32_f32_f32(size_t m, size_t n, size_t p,
+ const float *A, const float *B,
+ float *C, double scale);
+
+int matmul_f64_f64_f64(size_t m, size_t n, size_t p,
+ const double *A, const double *B,
+ double *C, double scale);
+
+// Scalar and SIMD variants (internal/specialized)
int matmul_scalar_u8_i8_u8(...);
-int matmul_avx2_u8_i8_u8(...);
-int matmul_avx512_u8_i8_u8(...);
-int matmul_avxvnni_u8_i8_u8(...);
+int matmul_avx512vnni_u8_i8_u8(...);
+int matmul_scalar_f32_f32_f32(...);
+int matmul_avx2_f32_f32_f32(...);
+int matmul_avx512_f32_f32_f32(...);
```
### Type Naming Conventions
@@ -114,18 +142,6 @@ make clean
Requires a C11 compiler (gcc, clang, MSVC) with OpenMP support.
-## Performance
-
-On an AMD Ryzen 7 9800X3D (Zen 5, 16 cores), with AVX2/AVX-VNNI/AVX512/AVX512-VNNI enabled:
-
-| Matrix Size | Time (ms) | GFLOPS |
-|-------------|-----------|---------|
-| 1024×1024 | ~2.4 | ~880 |
-| 4096×4096 | ~173 | ~795 |
-| 16384×16384 | ~12970 | ~678 |
-
-The 16384×16384 case exceeds L3 cache (~768MB total), so performance is memory-bound. For optimal performance use matrices that fit in L3 (≤4096 on most CPUs).
-
## Testing
The library includes unit tests verifying correctness across all implementations:
@@ -135,8 +151,11 @@ The library includes unit tests verifying correctness across all implementations
## Implementation Notes
-- **Automatic dispatch**: The first call runtime-detects CPU features and selects the optimal implementation
-- **Dispatch priority**: AVX512-VNNI → AVX512 → AVX-VNNI → AVX2 → Scalar
+- **Automatic dispatch**: The first call runtime-detects CPU features and selects the optimal implementation for the given types
+- **Dispatch priority**:
+ - `u8_i8_u8`: AVX512-VNNI → Scalar
+ - `f32_f32_f32`: AVX512 → AVX2 → Scalar
+ - `f64_f64_f64`: AVX512 → AVX2 → Scalar
- **Parallelization**: OpenMP `parallel for` with `static` scheduling across row blocks
- **Tiling**: Blocking factors tuned for L1/L2 cache (ib=32/64, jb=64, kb=32/64 depending on SIMD width)
diff --git a/src/matmul.h b/src/matmul.h
@@ -50,48 +50,19 @@ extern "C" {
#define __matmul_TYPE_f32 float
#define __matmul_TYPE_f64 double
-#define __matmul_EXTERN(__MATMUL_EXT_A, __MATMUL_EXT_B, __MATMUL_EXT_C) extern int (*matmul_##__MATMUL_EXT_A##_##__MATMUL_EXT_B##_##__MATMUL_EXT_C)(size_t, size_t, size_t, const __matmul_TYPE_##__MATMUL_EXT_A *, const __matmul_TYPE_##__MATMUL_EXT_B *, __matmul_TYPE_##__MATMUL_EXT_C *, double);
+extern int matmul_not_implemented(size_t m, size_t n, size_t p, void *A, void *B, void *C, double scale);
-#define __matmul_EXTERN_C(__MATMUL_EXT_A, __MATMUL_EXT_B) \
- __matmul_EXTERN(__MATMUL_EXT_A, __MATMUL_EXT_B, u8 ) \
- __matmul_EXTERN(__MATMUL_EXT_A, __MATMUL_EXT_B, f32) \
- __matmul_EXTERN(__MATMUL_EXT_A, __MATMUL_EXT_B, f64)
+extern int (*matmul_u8_i8_u8 )(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale);
+extern int (*matmul_f32_f32_f32)(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale);
+extern int (*matmul_f64_f64_f64)(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale);
-#define __matmul_EXTERN_B(__MATMUL_EXT_A) \
- __matmul_EXTERN_C(__MATMUL_EXT_A, i8 ) \
- __matmul_EXTERN_C(__MATMUL_EXT_A, f32 ) \
- __matmul_EXTERN_C(__MATMUL_EXT_A, f64 )
-
-__matmul_EXTERN_B(u8 )
-__matmul_EXTERN_B(f32)
-__matmul_EXTERN_B(f64)
-
-#define __matmul_C(__matmul_arg_type_a,__matmul_arg_type_b) \
- _Generic((__MATMUL_ARG_C), \
- uint8_t *: matmul_##__matmul_arg_type_a##_##__matmul_arg_type_b##_u8, \
- float *: matmul_##__matmul_arg_type_a##_##__matmul_arg_type_b##_f32, \
- double *: matmul_##__matmul_arg_type_a##_##__matmul_arg_type_b##_f64 \
- )
-
-#define __matmul_B(__matmul_arg_type_a) \
- _Generic((__MATMUL_ARG_B), \
- int8_t *: __matmul_C(__matmul_arg_type_a,i8) \
- const int8_t *: __matmul_C(__matmul_arg_type_a,i8) \
- float *: __matmul_C(__matmul_arg_type_a,f32) \
- const float *: __matmul_C(__matmul_arg_type_a,f32) \
- double *: __matmul_C(__matmul_arg_type_a,f64) \
- const double *: __matmul_C(__matmul_arg_type_a,f64) \
- )
-
-#define matmul(m,n,p,__MATMUL_ARG_A,__MATMUL_ARG_B,__MATMUL_ARG_C,scale) \
- _Generic((__MATMUL_ARG_A), \
- uint8_t *: __matmul_B(u8) \
- const uint8_t *: __matmul_B(u8) \
- float *: __matmul_B(f32) \
- const float *: __matmul_B(f32) \
- double *: __matmul_B(f64) \
- const double *: __matmul_B(f64) \
- )(m,n,p,__MATMUL_ARG_A,__MATMUL_ARG_B,__MATMUL_ARG_C,scale)
+#define matmul(m,n,p,A,B,C) \
+ _Generic((void (*)(__typeof(A),__typeof(B),__typeof(C)))NULL, \
+ default: matmul_not_implemented, \
+ void (*)(uint8_t *, int8_t *, uint8_t *): matmul_u8_i8_u8 \
+ void (*)(float *, float *, float *): matmul_f32_f32_f32 \
+ void (*)(double *, double *, double *): matmul_f64_f64_f64 \
+ )(m,n,p,A,B,C,scale)
#ifdef __cplusplus
}