commit 5e4fc7d0e0d87092d880d64949afe6dd13c68bd6
parent df550e3004136b431d34b051cde05616907d9f4b
Author: finwo <finwo@pm.me>
Date: Tue, 21 Apr 2026 18:33:05 +0200
Public api
Diffstat:
| M | src/matmul.h | | | 64 | ++++++++++++++++++++++++++++++++++++++++++++++------------------ |
1 file changed, 46 insertions(+), 18 deletions(-)
diff --git a/src/matmul.h b/src/matmul.h
@@ -48,25 +48,53 @@ extern "C" {
extern int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double);
extern int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double);
extern int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, double);
+extern int (*matmul_not_implemented)(size_t, size_t, size_t, const void *, const void *, void *, double);
-#define matmul(m, n, p, A, B, C, scale) \
- _Generic((A), \
- uint8_t *: _Generic((B), \
- int8_t *: _Generic((C), \
- uint8_t *: matmul_u8_i8_u8 \
- ) \
- ), \
- float *: _Generic((B), \
- float *: _Generic((C), \
- float *: matmul_f32_f32_f32 \
- ) \
- ), \
- double *: _Generic((B), \
- double *: _Generic((C), \
- double *: matmul_f64_f64_f64 \
- ) \
- ) \
- )((m), (n), (p), (A), (B), (C), (scale))
+#define matmul(m, n, p, A, B, C, scale) \
+ _Generic((A), \
+ const uint8_t *: _Generic((B), \
+ const int8_t *: _Generic((C), uint8_t *: matmul_u8_i8_u8), \
+ int8_t *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ const float *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ float *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ const double *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ double *: _Generic((C), uint8_t *: matmul_not_implemented)), \
+ uint8_t *: _Generic((B), \
+ const int8_t *: _Generic((C), uint8_t *: matmul_u8_i8_u8), \
+ int8_t *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ const float *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ float *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ const double *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ double *: _Generic((C), uint8_t *: matmul_not_implemented)), \
+ const float *: _Generic((B), \
+ const int8_t *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ int8_t *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ const float *: _Generic((C), float *: matmul_f32_f32_f32), \
+ float *: _Generic((C), float *: matmul_not_implemented), \
+ const double *: _Generic((C), float *: matmul_not_implemented), \
+ double *: _Generic((C), float *: matmul_not_implemented)), \
+ float *: _Generic((B), \
+ const int8_t *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ int8_t *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ const float *: _Generic((C), float *: matmul_f32_f32_f32), \
+ float *: _Generic((C), float *: matmul_not_implemented), \
+ const double *: _Generic((C), float *: matmul_not_implemented), \
+ double *: _Generic((C), float *: matmul_not_implemented)), \
+ const double *: _Generic((B), \
+ const int8_t *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ int8_t *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ const float *: _Generic((C), float *: matmul_not_implemented), \
+ float *: _Generic((C), float *: matmul_not_implemented), \
+ const double *: _Generic((C), double *: matmul_f64_f64_f64), \
+ double *: _Generic((C), double *: matmul_not_implemented)), \
+ double *: _Generic((B), \
+ const int8_t *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ int8_t *: _Generic((C), uint8_t *: matmul_not_implemented), \
+ const float *: _Generic((C), float *: matmul_not_implemented), \
+ float *: _Generic((C), float *: matmul_not_implemented), \
+ const double *: _Generic((C), double *: matmul_f64_f64_f64), \
+ double *: _Generic((C), double *: matmul_not_implemented)), \
+ void *: matmul_not_implemented)((m), (n), (p), (A), (B), (C), (scale))
#ifdef __cplusplus
}