commit 4615e6a5d18b3ee05831fab834628d51c011a6d9
parent 11f476190e1bacace411948da4e194c3f6080a93
Author: finwo <finwo@pm.me>
Date: Sat, 18 Apr 2026 18:45:29 +0200
Functional
Diffstat:
| M | .gitignore | | | 1 | + |
| M | README.md | | | 89 | ++++++++++++++++++++++++++++++++++++++++++++++--------------------------------- |
| M | src/matmul.c | | | 3468 | ++++--------------------------------------------------------------------------- |
| M | test/benchmark.c | | | 212 | ++++++++++++++++++++++++++++++------------------------------------------------- |
| M | test/test_matmul.c | | | 424 | +++++++++++++++++++++++-------------------------------------------------------- |
| M | test/test_matmul_simd.h | | | 122 | +++++-------------------------------------------------------------------------- |
6 files changed, 417 insertions(+), 3899 deletions(-)
diff --git a/.gitignore b/.gitignore
@@ -2,3 +2,4 @@
*.o
/test_matmul
/.dep
+/benchmark
diff --git a/README.md b/README.md
@@ -1,6 +1,9 @@
-# Matmul - High-Performance Matrix Multiplication Library
+# Matmul - Accelerated Matrix Multiplication Library
-A lightweight, type-safe matrix multiplication library with compile-time dispatch, scaling support, and infrastructure for SIMD acceleration.
+A lightweight, type-safe matrix multiplication library with runtime dispatch and SIMD acceleration.
+
+**Current Implementation:** `uint8_t` × `int8_t` → `uint8_t`
+*(additional type combinations planned)*
## Installation
@@ -14,12 +17,11 @@ Alternatively, you can include the [matmul.c](src/matmul.c) and [matmul.h](src/m
## Features
-- **64 Type Combinations**: Supports all combinations of f32/f64/i8/u8 for input matrices A, B and output C
-- **Compile-Time Dispatch**: Uses C11 `_Generic` for zero-overhead type detection
-- **Scale Parameter**: Divide output before writing (e.g., scale=64 for quantization emulation)
-- **Future SIMD Ready**: Infrastructure in place for AVX2/AVX512/VNNI acceleration
+- **SIMD Acceleration**: AVX2, AVX512, AVX-VNNI, and AVX512-VNNI with automatic CPU feature detection
+- **Multi-Core Parallelization**: OpenMP-based parallel processing across matrix rows
+- **Tiled Algorithm**: Cache-blocking for improved locality at large matrix sizes
- **No Dependencies**: Pure C11 implementation with no external libraries
-- **Well Tested**: 64 unit tests covering all type combinations
+- **Well Tested**: Unit tests verify correctness across all implementations
## Quick Start
@@ -28,25 +30,25 @@ Alternatively, you can include the [matmul.c](src/matmul.c) and [matmul.h](src/m
#include <stdio.h>
int main() {
- // 2x3 matrix A
- float A[6] = {1, 2, 3, 4, 5, 6};
- // 3x2 matrix B
- float B[6] = {1, 0, 0, 1, 0, 0};
- // 2x2 result matrix C
- float C[4];
+ // 2x3 matrix A (uint8_t)
+ uint8_t A[6] = {1, 2, 3, 4, 5, 6};
+ // 3x2 matrix B (int8_t)
+ int8_t B[6] = {1, 0, 0, 1, 0, 0};
+ // 2x2 result matrix C (uint8_t)
+ uint8_t C[4];
// Multiply A(2x3) * B(3x2) = C(2x2)
matmul(2, 3, 2, A, B, C, 0.0); // scale=0: no scaling
// C should be [1, 2, 4, 5]
- printf("C[0] = %f, C[1] = %f, C[2] = %f, C[3] = %f\n",
+ printf("C[0] = %u, C[1] = %u, C[2] = %u, C[3] = %u\n",
C[0], C[1], C[2], C[3]);
return 0;
}
```
-Compile with: `cc -o example example.c -lm`
+Compile with: `cc -o example example.c -lm -fopenmp`
## API Reference
@@ -58,19 +60,17 @@ matmul(m, n, p, A, B, C, scale);
- `scale`: Divide each output element by this value before writing (0 or 1 = no scaling)
### Direct Function Calls
-Each of the 64 type combinations is available directly:
```c
-// Floating point
-matmul_f32_f32_f32(m, n, p, A, B, C, scale);
-matmul_f32_f32_f64(m, n, p, A, B, C, scale);
-matmul_f32_f64_f32(m, n, p, A, B, C, scale);
-// ... etc for all 64 combinations
-
-// Integer types
-matmul_i8_i8_i8(m, n, p, A, B, C, scale);
-matmul_u8_u8_u8(m, n, p, A, B, C, scale);
-matmul_i8_u8_i8(m, n, p, A, B, C, scale);
-// ... etc
+// Currently implemented type combination (u8 × i8 → u8)
+int matmul_u8_i8_u8(size_t m, size_t n, size_t p,
+ const uint8_t *A, const int8_t *B,
+ uint8_t *C, double scale);
+
+// Scalar and SIMD variants
+int matmul_scalar_u8_i8_u8(...);
+int matmul_avx2_u8_i8_u8(...);
+int matmul_avx512_u8_i8_u8(...);
+int matmul_avxvnni_u8_i8_u8(...);
+int matmul_avx512vnni_u8_i8_u8(...);
```
### Type Naming Conventions
@@ -105,30 +105,45 @@ make
# Run tests
./test_matmul
+# Run benchmarks (optional, needs ~800MB RAM for 16K×16K)
+./benchmark
+
# Clean build artifacts
make clean
```
-Requires a C11 compiler (gcc, clang, MSVC).
+Requires a C11 compiler (gcc, clang, MSVC) with OpenMP support.
+
+## Performance
+
+On an AMD Ryzen 7 9800X3D (Zen 5, 8 cores / 16 threads), with AVX2/AVX-VNNI/AVX512/AVX512-VNNI enabled:
+
+| Matrix Size | Time (ms) | GFLOPS |
+|-------------|-----------|---------|
+| 1024×1024 | ~2.4 | ~880 |
+| 4096×4096 | ~173 | ~795 |
+| 16384×16384 | ~12970 | ~678 |
+
+The 16384×16384 case has a working set of ~768MB in total (three 256MB matrices), which far exceeds L3 cache, so performance is memory-bound. For optimal performance, use matrices that fit in L3 (≤4096 on most CPUs).
## Testing
-The library includes 64 unit tests covering all type combinations:
+The library includes unit tests verifying correctness across all implementations:
- Run: `./test_matmul`
-- Tests verify correctness against reference implementations
-- Output shows PASS/FAIL status for each type combination
+- Tests verify correctness against the reference scalar implementation
+- Output shows PASS/FAIL status for each implementation (scalar, AVX2, AVX512, dispatched)
-## Future Work
+## Implementation Notes
-SIMD acceleration infrastructure is already in place:
-- Auto-dispatch functions (`_matmul_*`) replace function pointers on first call
-- Ready for AVX2, AVX512, AVX512-VNNI, and AVX-VNNI implementations
-- When implemented, dispatch will select best available CPU features at runtime
-- Fallback chain: AVX512-VNNI → AVX512 → AVX2 → Scalar
+- **Automatic dispatch**: On the first call, CPU features are detected at runtime and the best available implementation is selected
+- **Dispatch priority**: AVX512-VNNI → AVX512 → AVX-VNNI → AVX2 → Scalar
+- **Parallelization**: OpenMP `parallel for` with `static` scheduling across row blocks
+- **Tiling**: Blocking factors tuned for L1/L2 cache (ib=32/64, jb=64, kb=32/64 depending on SIMD width)
## License
Licensed under custom terms (Copyright 2026 finwo); see LICENSE.md for full details.
---
+
*Built with C11. Zero runtime overhead for type dispatch.*
diff --git a/src/matmul.c b/src/matmul.c
@@ -42,3371 +42,209 @@
#include <stdlib.h>
#include <string.h>
+#ifdef __AVX2__
+#include <immintrin.h>
+#endif
+
+#ifdef __AVX512F__
+#include <immintrin.h>
+#endif
+
+#ifdef __AVXVNNI__
+#include <immintrin.h>
+#endif
+
+#ifdef __AVX512VNNI__
+#include <immintrin.h>
+#endif
+
#ifdef _OPENMP
#include <omp.h>
#endif
-#define MATMUL_FLAG_SCALAR (1 << 0)
-#define MATMUL_FLAG_AVX2 (1 << 1)
-#define MATMUL_FLAG_AVX512 (1 << 2)
-#define MATMUL_FLAG_AVX512_VNNI (1 << 3)
-#define MATMUL_FLAG_AVXVNNI (1 << 4)
+#define MATMUL_FLAG_SCALAR (1 << 0)
+#define MATMUL_FLAG_AVX2 (1 << 1)
+#define MATMUL_FLAG_AVXVNNI (1 << 2)
+#define MATMUL_FLAG_AVX512 (1 << 3)
+#define MATMUL_FLAG_AVX512VNNI (1 << 4)
typedef uint32_t matmul_feature_t;
-static matmul_feature_t g_feature = 0;
+static matmul_feature_t g_feature = 0;
+static int g_initialized = 0;
static void init_feature(void) {
- if (g_feature != 0) return;
g_feature = MATMUL_FLAG_SCALAR;
-#ifdef __AVX2__
- if (__builtin_cpu_supports("avx2")) g_feature |= MATMUL_FLAG_AVX2;
+#ifdef __AVX512VNNI__
+ if (__builtin_cpu_supports("avx512vnni")) g_feature |= MATMUL_FLAG_AVX512VNNI;
#endif
#ifdef __AVX512F__
- if (__builtin_cpu_supports("avx512f")) {
- g_feature |= MATMUL_FLAG_AVX512;
- if (__builtin_cpu_supports("avx512vnni")) g_feature |= MATMUL_FLAG_AVX512_VNNI;
- }
+ if (__builtin_cpu_supports("avx512f")) g_feature |= MATMUL_FLAG_AVX512;
+#endif
+#ifdef __AVXVNNI__
+ if (__builtin_cpu_supports("avxvnni")) g_feature |= MATMUL_FLAG_AVXVNNI;
#endif
#ifdef __AVX2__
- if (__builtin_cpu_supports("avx2") && __builtin_cpu_supports("avxvnni")) g_feature |= MATMUL_FLAG_AVXVNNI;
+ if (__builtin_cpu_supports("avx2")) g_feature |= MATMUL_FLAG_AVX2;
#endif
}
matmul_feature_t matmul_get_feature(void) {
- init_feature();
+ if (!g_initialized) {
+ init_feature();
+ g_initialized = 1;
+ }
return g_feature;
}
const char *matmul_get_feature_name(matmul_feature_t feat) {
+ if (feat & MATMUL_FLAG_AVX512VNNI) return "avx512vnni";
if (feat & MATMUL_FLAG_AVX512) return "avx512";
+ if (feat & MATMUL_FLAG_AVXVNNI) return "avxvnni";
if (feat & MATMUL_FLAG_AVX2) return "avx2";
- if (feat & MATMUL_FLAG_AVX512_VNNI) return "avx512_vnni";
- if (feat & MATMUL_FLAG_AVXVNNI) return "avx_vnni";
if (feat & MATMUL_FLAG_SCALAR) return "scalar";
return "unknown";
}
-int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += A[i * n + k] * B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * (double)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * (double)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * (double)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
+int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { // Scalar tiled kernel: C (m x p, u8) = saturate(A (m x n, u8) * B (n x p, i8)), all row-major flat arrays; always returns 0
+ (void)scale; // NOTE(review): scale is accepted but never used here, so no division is applied to the output
+ const size_t ib = 64; // tile height (rows of A / C per tile)
+ const size_t jb = 64; // tile width (columns of B / C per tile)
+ const size_t kb = 32; // chunk length along the shared dimension n
+
+#pragma omp parallel for schedule(static)
+ for (size_t ii = 0; ii < m; ii += ib) { // row tiles; this is the loop the pragma above applies to
+ size_t i_end = (ii + ib < m) ? ii + ib : m; // clip the last row tile to m
+ for (size_t jj = 0; jj < p; jj += jb) {
+ size_t j_end = (jj + jb < p) ? jj + jb : p; // clip the last column tile to p
+ size_t ti = i_end - ii; // actual tile height (<= ib)
+ size_t tj = j_end - jj; // actual tile width (<= jb)
+ int32_t acc[64 * 64]; // per-tile accumulators; the literal size must cover ib * jb
+ memset(acc, 0, ti * tj * sizeof(int32_t)); // only the ti * tj entries in use are cleared
+
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_end = (kk + kb < n) ? kk + kb : n; // clip the last k chunk to n
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii; // row index inside the tile
+ for (size_t j = jj; j < j_end; j++) {
+ size_t lj = j - jj; // column index inside the tile
+ int sum = 0; // partial dot product for this k chunk
+ for (size_t k = kk; k < k_end; k++) {
+ sum += (int)A[i * n + k] * (int)B[k * p + j]; // both operands widened to int before multiplying
+ }
+ acc[li * tj + lj] += sum; // tile-local storage uses the actual width tj as its row stride
+ }
+ }
+ }
+
+ for (size_t i = ii; i < i_end; i++) { // write-back pass, run after all k chunks have been accumulated
+ size_t li = i - ii;
+ for (size_t j = jj; j < j_end; j++) {
+ size_t lj = j - jj;
+ int v = acc[li * tj + lj];
+ if (v > 255) // saturate to the uint8_t range [0, 255]
+ v = 255;
+ else if (v < 0)
+ v = 0;
+ C[i * p + j] = (uint8_t)v;
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+#ifdef __AVX512VNNI__
+int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C,
+ double scale) {
+ (void)scale;
+ const size_t ib = 64;
+ const size_t jb = 64;
+ const size_t kb = 32;
-int matmul_scalar_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
+ for (size_t ii = 0; ii < m; ii += ib) {
+ size_t i_end = (ii + ib < m) ? ii + ib : m;
+ for (size_t jj = 0; jj < p; jj += jb) {
+ size_t j_end = (jj + jb < p) ? jj + jb : p;
+ size_t ti = i_end - ii;
+ size_t tj = j_end - jj;
+ int32_t acc[64 * 64];
+ memset(acc, 0, ti * tj * sizeof(int32_t));
-int matmul_scalar_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
+ for (size_t kk = 0; kk < n; kk += kb) {
+ size_t k_limit = (kk + kb < n) ? kk + kb : n;
-int matmul_scalar_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += A[i * n + k] * (double)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ int32_t *acc_row = &acc[li * tj];
-int matmul_scalar_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
+ for (size_t j = jj; j < j_end; j += 16) {
+ size_t j_chunk = (j + 16 <= j_end) ? 16 : (j_end - j);
+ __m512i result = _mm512_setzero_si512();
-int matmul_scalar_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
+ for (size_t k = kk; k < k_limit; k += 4) {
+ size_t k_chunk = (k + 4 <= k_limit) ? 4 : (k_limit - k);
-int matmul_scalar_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
+ uint32_t a4 = 0;
+ for (size_t dk = 0; dk < k_chunk; dk++) {
+ a4 |= (uint32_t)A[i * n + k + dk] << (dk * 8);
+ }
+ __m512i a_val = _mm512_set1_epi32(a4);
-int matmul_scalar_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += A[i * n + k] * B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
+ int8_t b_buf[64] = {0};
+ for (size_t dj = 0; dj < j_chunk; dj++) {
+ for (size_t dk = 0; dk < k_chunk; dk++) {
+ b_buf[dj * 4 + dk] = B[(k + dk) * p + j + dj];
+ }
+ }
+ __m512i b_val = _mm512_load_si512((__m512i const *)b_buf);
-int matmul_scalar_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
+ result = _mm512_dpbusd_epi32(result, a_val, b_val);
+ }
-int matmul_scalar_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
+ int32_t tmp[16] __attribute__((aligned(64)));
+ _mm512_store_si512(tmp, result);
-int matmul_scalar_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
+ size_t j_offset = j - jj;
+ for (size_t c = 0; c < j_chunk; c++) {
+ acc_row[j_offset + c] += tmp[c];
+ }
+ }
+ }
}
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-int matmul_scalar_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += A[i * n + k] * (double)B[k * p + j];
+ for (size_t i = ii; i < i_end; i++) {
+ size_t li = i - ii;
+ for (size_t j = jj; j < j_end; j++) {
+ size_t lj = j - jj;
+ int32_t v = acc[li * tj + lj];
+ if (v > 255)
+ v = 255;
+ else if (v < 0)
+ v = 0;
+ C[i * p + j] = (uint8_t)v;
+ }
}
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
}
}
return 0;
}
-
-int matmul_scalar_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-int matmul_scalar_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
+static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
+ static int initialized = 0;
+ if (!initialized) {
+ matmul_feature_t feat = matmul_get_feature();
+#ifdef __AVX512VNNI__
+ if (feat & MATMUL_FLAG_AVX512VNNI)
+ matmul_u8_i8_u8 = matmul_avx512vnni_u8_i8_u8;
+ else
#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += A[i * n + k] * (double)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * (double)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * (double)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * (double)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * (double)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * (double)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- float sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (float)A[i * n + k] * (float)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= (float)scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- double sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (double)A[i * n + k] * (double)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum /= scale;
- C[i * p + j] = sum;
- }
- }
- return 0;
-}
-
-int matmul_scalar_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 127)
- sum = 127;
- else if (sum < -128)
- sum = -128;
- C[i * p + j] = (int8_t)sum;
- }
- }
- return 0;
-}
-
-#ifdef __AVX2__
-int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256 sum_vec = _mm256_setzero_ps();
- size_t k = 0;
- for (; k + 7 < n; k += 8) {
- __m256 a = _mm256_loadu_ps(&A[i * n + k]);
- __m256 b = _mm256_loadu_ps(&B[k * p + j]);
- sum_vec = _mm256_fmadd_ps(a, b, sum_vec);
- }
- float sum[8];
- _mm256_storeu_ps(sum, sum_vec);
- float s = sum[0] + sum[1] + sum[2] + sum[3] + sum[4] + sum[5] + sum[6] + sum[7];
- for (; k < n; k++) s += A[i * n + k] * B[k * p + j];
- if (scale != 0 && scale != 1) s /= (float)scale;
- C[i * p + j] = s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256d sum_vec = _mm256_setzero_pd();
- size_t k = 0;
- for (; k + 3 < n; k += 4) {
- __m256 a = _mm256_castpd_ps(_mm256_loadu_pd((const double *)&A[i * n + k]));
- __m256 b = _mm256_castpd_ps(_mm256_loadu_pd((const double *)&B[k * p + j]));
- __m256 mul = _mm256_mul_ps(a, b);
- sum_vec = _mm256_add_pd(_mm256_castps_pd(mul), sum_vec);
- }
- double sum[4];
- _mm256_storeu_pd(sum, sum_vec);
- double s = sum[0] + sum[1] + sum[2] + sum[3];
- for (; k < n; k++) s += (double)A[i * n + k] * (double)B[k * p + j];
- if (scale != 0 && scale != 1) s /= scale;
- C[i * p + j] = s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) {
- return matmul_scalar_f32_f64_f32(m, n, p, A, B, C, scale);
-}
-
-int matmul_avx2_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) {
- return matmul_scalar_f32_f64_f64(m, n, p, A, B, C, scale);
-}
-
-int matmul_avx2_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) {
- return matmul_scalar_f64_f32_f32(m, n, p, A, B, C, scale);
-}
-
-int matmul_avx2_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) {
- return matmul_scalar_f64_f32_f64(m, n, p, A, B, C, scale);
-}
-
-int matmul_avx2_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) {
- return matmul_scalar_f64_f64_f32(m, n, p, A, B, C, scale);
-}
-
-int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256d sum_vec = _mm256_setzero_pd();
- size_t k = 0;
- for (; k + 3 < n; k += 4) {
- __m256d a = _mm256_loadu_pd(&A[i * n + k]);
- __m256d b = _mm256_loadu_pd(&B[k * p + j]);
- sum_vec = _mm256_fmadd_pd(a, b, sum_vec);
- }
- double sum[4];
- _mm256_storeu_pd(sum, sum_vec);
- double s = sum[0] + sum[1] + sum[2] + sum[3];
- for (; k < n; k++) s += A[i * n + k] * B[k * p + j];
- if (scale != 0 && scale != 1) s /= scale;
- C[i * p + j] = s;
- }
- }
- return 0;
-}
-#endif
-
-#ifdef __AVX512F__
-int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512 sum_vec = _mm512_setzero_ps();
- size_t k = 0;
- for (; k + 15 < n; k += 16) {
- __m512 a = _mm512_loadu_ps(&A[i * n + k]);
- __m512 b = _mm512_loadu_ps(&B[k * p + j]);
- sum_vec = _mm512_fmadd_ps(a, b, sum_vec);
- }
- float s = _mm512_reduce_add_ps(sum_vec);
- for (; k < n; k++) s += A[i * n + k] * B[k * p + j];
- if (scale != 0 && scale != 1) s /= (float)scale;
- C[i * p + j] = s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) {
- return matmul_scalar_f32_f32_f64(m, n, p, A, B, C, scale);
-}
-
-int matmul_avx512_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) {
- return matmul_scalar_f32_f64_f32(m, n, p, A, B, C, scale);
-}
-
-int matmul_avx512_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) {
- return matmul_scalar_f32_f64_f64(m, n, p, A, B, C, scale);
-}
-
-int matmul_avx512_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) {
- return matmul_scalar_f64_f32_f32(m, n, p, A, B, C, scale);
-}
-
-int matmul_avx512_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) {
- return matmul_scalar_f64_f32_f64(m, n, p, A, B, C, scale);
-}
-
-int matmul_avx512_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) {
- return matmul_scalar_f64_f64_f32(m, n, p, A, B, C, scale);
-}
-
-int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512d sum_vec = _mm512_setzero_pd();
- size_t k = 0;
- for (; k + 7 < n; k += 8) {
- __m512d a = _mm512_loadu_pd(&A[i * n + k]);
- __m512d b = _mm512_loadu_pd(&B[k * p + j]);
- sum_vec = _mm512_fmadd_pd(a, b, sum_vec);
- }
- double s = _mm512_reduce_add_pd(sum_vec);
- for (; k < n; k++) s += A[i * n + k] * B[k * p + j];
- if (scale != 0 && scale != 1) s /= scale;
- C[i * p + j] = s;
- }
- }
- return 0;
-}
-#endif
-
-#ifdef __AVX2__
-static inline int32_t reduce_add_i32x8(__m256i v) {
- __m128i low = _mm256_extracti128_si256(v, 0);
- __m128i high = _mm256_extracti128_si256(v, 1);
- __m128i sum = _mm_add_epi32(low, high);
- sum = _mm_hadd_epi32(sum, sum);
- sum = _mm_hadd_epi32(sum, sum);
- return _mm_cvtsi128_si32(sum);
-}
-
-int matmul_avx2_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- C[i * p + j] = (float)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int32_t sum[8];
- _mm256_storeu_si256((__m256i *)sum, sum_vec);
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((double)s / scale);
- C[i * p + j] = (double)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int32_t sum[8];
- _mm256_storeu_si256((__m256i *)sum, sum_vec);
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 127)
- s = 127;
- else if (s < -128)
- s = -128;
- C[i * p + j] = (int8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int32_t sum[8];
- _mm256_storeu_si256((__m256i *)sum, sum_vec);
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 255)
- s = 255;
- else if (s < 0)
- s = 0;
- C[i * p + j] = (uint8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- C[i * p + j] = (float)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((double)s / scale);
- C[i * p + j] = (double)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 127)
- s = 127;
- else if (s < -128)
- s = -128;
- C[i * p + j] = (int8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 255)
- s = 255;
- else if (s < 0)
- s = 0;
- C[i * p + j] = (uint8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- C[i * p + j] = (float)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((double)s / scale);
- C[i * p + j] = (double)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 127)
- s = 127;
- else if (s < -128)
- s = -128;
- C[i * p + j] = (int8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 255)
- s = 255;
- else if (s < 0)
- s = 0;
- C[i * p + j] = (uint8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- C[i * p + j] = (float)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((double)s / scale);
- C[i * p + j] = (double)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 127)
- s = 127;
- else if (s < -128)
- s = -128;
- C[i * p + j] = (int8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx2_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m256i sum_vec = _mm256_setzero_si256();
- size_t k = 0;
- for (; k + 31 < n; k += 32) {
- __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo);
- __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi);
- sum_vec = _mm256_add_epi32(sum_vec, mul_lo);
- sum_vec = _mm256_add_epi32(sum_vec, mul_hi);
- }
- int s = reduce_add_i32x8(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 255)
- s = 255;
- else if (s < 0)
- s = 0;
- C[i * p + j] = (uint8_t)s;
- }
- }
- return 0;
-}
-#endif
-
-#ifdef __AVX512F__
-static inline int32_t reduce_add_i32x16(__m512i v) {
- __m256i low = _mm512_extracti64x4_epi64(v, 0);
- __m256i high = _mm512_extracti64x4_epi64(v, 1);
- __m256i sum = _mm256_add_epi32(low, high);
- sum = _mm256_hadd_epi32(sum, sum);
- sum = _mm256_hadd_epi32(sum, sum);
- __m128i s128 = _mm256_extracti128_si256(sum, 0);
- s128 = _mm_hadd_epi32(s128, s128);
- s128 = _mm_hadd_epi32(s128, s128);
- return _mm_cvtsi128_si32(s128);
-}
-
-int matmul_avx512_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- C[i * p + j] = (float)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((double)s / scale);
- C[i * p + j] = (double)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 127)
- s = 127;
- else if (s < -128)
- s = -128;
- C[i * p + j] = (int8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 255)
- s = 255;
- else if (s < 0)
- s = 0;
- C[i * p + j] = (uint8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- C[i * p + j] = (float)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((double)s / scale);
- C[i * p + j] = (double)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 127)
- s = 127;
- else if (s < -128)
- s = -128;
- C[i * p + j] = (int8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 255)
- s = 255;
- else if (s < 0)
- s = 0;
- C[i * p + j] = (uint8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- C[i * p + j] = (float)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((double)s / scale);
- C[i * p + j] = (double)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 127)
- s = 127;
- else if (s < -128)
- s = -128;
- C[i * p + j] = (int8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 255)
- s = 255;
- else if (s < 0)
- s = 0;
- C[i * p + j] = (uint8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- C[i * p + j] = (float)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((double)s / scale);
- C[i * p + j] = (double)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 127)
- s = 127;
- else if (s < -128)
- s = -128;
- C[i * p + j] = (int8_t)s;
- }
- }
- return 0;
-}
-
-int matmul_avx512_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- __m512i sum_vec = _mm512_setzero_si512();
- size_t k = 0;
- for (; k + 63 < n; k += 64) {
- __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
- __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
- __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
- __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
- __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j]));
- __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16]));
- __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32]));
- __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48]));
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
- sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
- }
- int s = reduce_add_i32x16(sum_vec);
- for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
- if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
- if (s > 255)
- s = 255;
- else if (s < 0)
- s = 0;
- C[i * p + j] = (uint8_t)s;
- }
- }
- return 0;
-}
-#endif
-
-int matmul_scalar_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) {
-#ifdef _OPENMP
-#pragma omp parallel for schedule(static)
-#endif
- for (size_t i = 0; i < m; i++) {
- for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
- }
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
- }
- }
- return 0;
-}
-
-static int _matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- matmul_f32_f32_f32 = matmul_avx512_f32_f32_f32;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- matmul_f32_f32_f32 = matmul_avx2_f32_f32_f32;
- else
-#endif
- matmul_f32_f32_f32 = matmul_scalar_f32_f32_f32;
- initialized = 1;
- }
- return matmul_f32_f32_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- matmul_f32_f32_f64 = matmul_avx512_f32_f32_f64;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- matmul_f32_f32_f64 = matmul_avx2_f32_f32_f64;
- else
-#endif
- matmul_f32_f32_f64 = matmul_scalar_f32_f32_f64;
- initialized = 1;
- }
- return matmul_f32_f32_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_f32_i8 = matmul_scalar_f32_f32_i8;
- initialized = 1;
- }
- return matmul_f32_f32_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_f32_u8 = matmul_scalar_f32_f32_u8;
- initialized = 1;
- }
- return matmul_f32_f32_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- matmul_f32_f64_f32 = matmul_avx512_f32_f64_f32;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- matmul_f32_f64_f32 = matmul_avx2_f32_f64_f32;
- else
-#endif
- matmul_f32_f64_f32 = matmul_scalar_f32_f64_f32;
- initialized = 1;
- }
- return matmul_f32_f64_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- matmul_f32_f64_f64 = matmul_avx512_f32_f64_f64;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- matmul_f32_f64_f64 = matmul_avx2_f32_f64_f64;
- else
-#endif
- matmul_f32_f64_f64 = matmul_scalar_f32_f64_f64;
- initialized = 1;
- }
- return matmul_f32_f64_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_f64_i8 = matmul_scalar_f32_f64_i8;
- initialized = 1;
- }
- return matmul_f32_f64_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_f64_u8 = matmul_scalar_f32_f64_u8;
- initialized = 1;
- }
- return matmul_f32_f64_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_i8_f32 = matmul_scalar_f32_i8_f32;
- initialized = 1;
- }
- return matmul_f32_i8_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_i8_f64 = matmul_scalar_f32_i8_f64;
- initialized = 1;
- }
- return matmul_f32_i8_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_i8_i8 = matmul_scalar_f32_i8_i8;
- initialized = 1;
- }
- return matmul_f32_i8_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_i8_u8 = matmul_scalar_f32_i8_u8;
- initialized = 1;
- }
- return matmul_f32_i8_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_u8_f32 = matmul_scalar_f32_u8_f32;
- initialized = 1;
- }
- return matmul_f32_u8_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_u8_f64 = matmul_scalar_f32_u8_f64;
- initialized = 1;
- }
- return matmul_f32_u8_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_u8_i8 = matmul_scalar_f32_u8_i8;
- initialized = 1;
- }
- return matmul_f32_u8_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f32_u8_u8 = matmul_scalar_f32_u8_u8;
- initialized = 1;
- }
- return matmul_f32_u8_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- matmul_f64_f32_f32 = matmul_avx512_f64_f32_f32;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- matmul_f64_f32_f32 = matmul_avx2_f64_f32_f32;
- else
-#endif
- matmul_f64_f32_f32 = matmul_scalar_f64_f32_f32;
- initialized = 1;
- }
- return matmul_f64_f32_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- matmul_f64_f32_f64 = matmul_avx512_f64_f32_f64;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- matmul_f64_f32_f64 = matmul_avx2_f64_f32_f64;
- else
-#endif
- matmul_f64_f32_f64 = matmul_scalar_f64_f32_f64;
- initialized = 1;
- }
- return matmul_f64_f32_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_f32_i8 = matmul_scalar_f64_f32_i8;
- initialized = 1;
- }
- return matmul_f64_f32_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_f32_u8 = matmul_scalar_f64_f32_u8;
- initialized = 1;
- }
- return matmul_f64_f32_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- matmul_f64_f64_f32 = matmul_avx512_f64_f64_f32;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- matmul_f64_f64_f32 = matmul_avx2_f64_f64_f32;
- else
-#endif
- matmul_f64_f64_f32 = matmul_scalar_f64_f64_f32;
- initialized = 1;
- }
- return matmul_f64_f64_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C,
- double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- matmul_f64_f64_f64 = matmul_avx512_f64_f64_f64;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- matmul_f64_f64_f64 = matmul_avx2_f64_f64_f64;
- else
-#endif
- matmul_f64_f64_f64 = matmul_scalar_f64_f64_f64;
- initialized = 1;
- }
- return matmul_f64_f64_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_f64_i8 = matmul_scalar_f64_f64_i8;
- initialized = 1;
- }
- return matmul_f64_f64_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C,
- double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_f64_u8 = matmul_scalar_f64_f64_u8;
- initialized = 1;
- }
- return matmul_f64_f64_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_i8_f32 = matmul_scalar_f64_i8_f32;
- initialized = 1;
- }
- return matmul_f64_i8_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_i8_f64 = matmul_scalar_f64_i8_f64;
- initialized = 1;
- }
- return matmul_f64_i8_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_i8_i8 = matmul_scalar_f64_i8_i8;
- initialized = 1;
- }
- return matmul_f64_i8_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_i8_u8 = matmul_scalar_f64_i8_u8;
- initialized = 1;
- }
- return matmul_f64_i8_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_u8_f32 = matmul_scalar_f64_u8_f32;
- initialized = 1;
- }
- return matmul_f64_u8_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C,
- double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_u8_f64 = matmul_scalar_f64_u8_f64;
- initialized = 1;
- }
- return matmul_f64_u8_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_u8_i8 = matmul_scalar_f64_u8_i8;
- initialized = 1;
- }
- return matmul_f64_u8_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C,
- double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_f64_u8_u8 = matmul_scalar_f64_u8_u8;
- initialized = 1;
- }
- return matmul_f64_u8_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_i8_f32_f32 = matmul_scalar_i8_f32_f32;
- initialized = 1;
- }
- return matmul_i8_f32_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_i8_f32_f64 = matmul_scalar_i8_f32_f64;
- initialized = 1;
- }
- return matmul_i8_f32_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_i8_f32_i8 = matmul_scalar_i8_f32_i8;
- initialized = 1;
- }
- return matmul_i8_f32_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_i8_f32_u8 = matmul_scalar_i8_f32_u8;
- initialized = 1;
- }
- return matmul_i8_f32_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_i8_f64_f32 = matmul_scalar_i8_f64_f32;
- initialized = 1;
- }
- return matmul_i8_f64_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_i8_f64_f64 = matmul_scalar_i8_f64_f64;
- initialized = 1;
- }
- return matmul_i8_f64_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_i8_f64_i8 = matmul_scalar_i8_f64_i8;
- initialized = 1;
- }
- return matmul_i8_f64_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_i8_f64_u8 = matmul_scalar_i8_f64_u8;
- initialized = 1;
- }
- return matmul_i8_f64_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_i8_i8_f32 = matmul_avx512_i8_i8_f32;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_i8_i8_f32 = matmul_avx2_i8_i8_f32;
- else
-#endif
- matmul_i8_i8_f32 = matmul_scalar_i8_i8_f32;
- initialized = 1;
- }
- return matmul_i8_i8_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_i8_i8_f64 = matmul_avx512_i8_i8_f64;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_i8_i8_f64 = matmul_avx2_i8_i8_f64;
- else
-#endif
- matmul_i8_i8_f64 = matmul_scalar_i8_i8_f64;
- initialized = 1;
- }
- return matmul_i8_i8_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_i8_i8_i8 = matmul_avx512_i8_i8_i8;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_i8_i8_i8 = matmul_avx2_i8_i8_i8;
- else
-#endif
- matmul_i8_i8_i8 = matmul_scalar_i8_i8_i8;
- initialized = 1;
- }
- return matmul_i8_i8_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_i8_i8_u8 = matmul_avx512_i8_i8_u8;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_i8_i8_u8 = matmul_avx2_i8_i8_u8;
- else
-#endif
- matmul_i8_i8_u8 = matmul_scalar_i8_i8_u8;
- initialized = 1;
- }
- return matmul_i8_i8_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_i8_u8_f32 = matmul_avx512_i8_u8_f32;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_i8_u8_f32 = matmul_avx2_i8_u8_f32;
- else
-#endif
- matmul_i8_u8_f32 = matmul_scalar_i8_u8_f32;
- initialized = 1;
- }
- return matmul_i8_u8_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_i8_u8_f64 = matmul_avx512_i8_u8_f64;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_i8_u8_f64 = matmul_avx2_i8_u8_f64;
- else
-#endif
- matmul_i8_u8_f64 = matmul_scalar_i8_u8_f64;
- initialized = 1;
- }
- return matmul_i8_u8_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_i8_u8_i8 = matmul_avx512_i8_u8_i8;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_i8_u8_i8 = matmul_avx2_i8_u8_i8;
- else
-#endif
- matmul_i8_u8_i8 = matmul_scalar_i8_u8_i8;
- initialized = 1;
- }
- return matmul_i8_u8_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_i8_u8_u8 = matmul_avx512_i8_u8_u8;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_i8_u8_u8 = matmul_avx2_i8_u8_u8;
- else
-#endif
- matmul_i8_u8_u8 = matmul_scalar_i8_u8_u8;
- initialized = 1;
- }
- return matmul_i8_u8_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_u8_i8_f32 = matmul_avx512_u8_i8_f32;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_u8_i8_f32 = matmul_avx2_u8_i8_f32;
- else
-#endif
- matmul_u8_i8_f32 = matmul_scalar_u8_i8_f32;
- initialized = 1;
- }
- return matmul_u8_i8_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_u8_i8_f64 = matmul_avx512_u8_i8_f64;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_u8_i8_f64 = matmul_avx2_u8_i8_f64;
- else
-#endif
- matmul_u8_i8_f64 = matmul_scalar_u8_i8_f64;
- initialized = 1;
- }
- return matmul_u8_i8_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_u8_i8_i8 = matmul_avx512_u8_i8_i8;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_u8_i8_i8 = matmul_avx2_u8_i8_i8;
- else
-#endif
- matmul_u8_i8_i8 = matmul_scalar_u8_i8_i8;
- initialized = 1;
- }
- return matmul_u8_i8_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_u8_i8_u8 = matmul_avx512_u8_i8_u8;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_u8_i8_u8 = matmul_avx2_u8_i8_u8;
- else
-#endif
- matmul_u8_i8_u8 = matmul_scalar_u8_i8_u8;
- initialized = 1;
+ matmul_u8_i8_u8 = matmul_scalar_u8_i8_u8;
+ initialized = 1;
}
return matmul_u8_i8_u8(m, n, p, A, B, C, scale);
}
-static int _matmul_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_u8_u8_f32 = matmul_avx512_u8_u8_f32;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_u8_u8_f32 = matmul_avx2_u8_u8_f32;
- else
-#endif
- matmul_u8_u8_f32 = matmul_scalar_u8_u8_f32;
- initialized = 1;
- }
- return matmul_u8_u8_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C,
- double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_u8_u8_f64 = matmul_avx512_u8_u8_f64;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_u8_u8_f64 = matmul_avx2_u8_u8_f64;
- else
-#endif
- matmul_u8_u8_f64 = matmul_scalar_u8_u8_f64;
- initialized = 1;
- }
- return matmul_u8_u8_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_u8_u8_i8 = matmul_avx512_u8_u8_i8;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_u8_u8_i8 = matmul_avx2_u8_u8_i8;
- else
-#endif
- matmul_u8_u8_i8 = matmul_scalar_u8_u8_i8;
- initialized = 1;
- }
- return matmul_u8_u8_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C,
- double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512_VNNI)
- matmul_u8_u8_u8 = matmul_avx512_u8_u8_u8;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVXVNNI)
- matmul_u8_u8_u8 = matmul_avx2_u8_u8_u8;
- else
-#endif
- matmul_u8_u8_u8 = matmul_scalar_u8_u8_u8;
- initialized = 1;
- }
- return matmul_u8_u8_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_u8_f32_f32 = matmul_scalar_u8_f32_f32;
- initialized = 1;
- }
- return matmul_u8_f32_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_u8_f32_f64 = matmul_scalar_u8_f32_f64;
- initialized = 1;
- }
- return matmul_u8_f32_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_u8_f32_i8 = matmul_scalar_u8_f32_i8;
- initialized = 1;
- }
- return matmul_u8_f32_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_u8_f32_u8 = matmul_scalar_u8_f32_u8;
- initialized = 1;
- }
- return matmul_u8_f32_u8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_u8_f64_f32 = matmul_scalar_u8_f64_f32;
- initialized = 1;
- }
- return matmul_u8_f64_f32(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C,
- double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_u8_f64_f64 = matmul_scalar_u8_f64_f64;
- initialized = 1;
- }
- return matmul_u8_f64_f64(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_u8_f64_i8 = matmul_scalar_u8_f64_i8;
- initialized = 1;
- }
- return matmul_u8_f64_i8(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C,
- double scale) {
- static int initialized = 0;
- if (!initialized) {
- matmul_u8_f64_u8 = matmul_scalar_u8_f64_u8;
- initialized = 1;
- }
- return matmul_u8_f64_u8(m, n, p, A, B, C, scale);
-}
-
-int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double) = _matmul_f32_f32_f32;
-int (*matmul_f32_f32_f64)(size_t, size_t, size_t, const float *, const float *, double *, double) = _matmul_f32_f32_f64;
-int (*matmul_f32_f32_i8)(size_t, size_t, size_t, const float *, const float *, int8_t *, double) = _matmul_f32_f32_i8;
-int (*matmul_f32_f32_u8)(size_t, size_t, size_t, const float *, const float *, uint8_t *, double) = _matmul_f32_f32_u8;
-int (*matmul_f32_f64_f32)(size_t, size_t, size_t, const float *, const double *, float *, double) = _matmul_f32_f64_f32;
-int (*matmul_f32_f64_f64)(size_t, size_t, size_t, const float *, const double *, double *,
- double) = _matmul_f32_f64_f64;
-int (*matmul_f32_f64_i8)(size_t, size_t, size_t, const float *, const double *, int8_t *, double) = _matmul_f32_f64_i8;
-int (*matmul_f32_f64_u8)(size_t, size_t, size_t, const float *, const double *, uint8_t *, double) = _matmul_f32_f64_u8;
-int (*matmul_f32_i8_f32)(size_t, size_t, size_t, const float *, const int8_t *, float *, double) = _matmul_f32_i8_f32;
-int (*matmul_f32_i8_f64)(size_t, size_t, size_t, const float *, const int8_t *, double *, double) = _matmul_f32_i8_f64;
-int (*matmul_f32_i8_i8)(size_t, size_t, size_t, const float *, const int8_t *, int8_t *, double) = _matmul_f32_i8_i8;
-int (*matmul_f32_i8_u8)(size_t, size_t, size_t, const float *, const int8_t *, uint8_t *, double) = _matmul_f32_i8_u8;
-int (*matmul_f32_u8_f32)(size_t, size_t, size_t, const float *, const uint8_t *, float *, double) = _matmul_f32_u8_f32;
-int (*matmul_f32_u8_f64)(size_t, size_t, size_t, const float *, const uint8_t *, double *, double) = _matmul_f32_u8_f64;
-int (*matmul_f32_u8_i8)(size_t, size_t, size_t, const float *, const uint8_t *, int8_t *, double) = _matmul_f32_u8_i8;
-int (*matmul_f32_u8_u8)(size_t, size_t, size_t, const float *, const uint8_t *, uint8_t *, double) = _matmul_f32_u8_u8;
-int (*matmul_f64_f32_f32)(size_t, size_t, size_t, const double *, const float *, float *, double) = _matmul_f64_f32_f32;
-int (*matmul_f64_f32_f64)(size_t, size_t, size_t, const double *, const float *, double *,
- double) = _matmul_f64_f32_f64;
-int (*matmul_f64_f32_i8)(size_t, size_t, size_t, const double *, const float *, int8_t *, double) = _matmul_f64_f32_i8;
-int (*matmul_f64_f32_u8)(size_t, size_t, size_t, const double *, const float *, uint8_t *, double) = _matmul_f64_f32_u8;
-int (*matmul_f64_f64_f32)(size_t, size_t, size_t, const double *, const double *, float *,
- double) = _matmul_f64_f64_f32;
-int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *,
- double) = _matmul_f64_f64_f64;
-int (*matmul_f64_f64_i8)(size_t, size_t, size_t, const double *, const double *, int8_t *, double) = _matmul_f64_f64_i8;
-int (*matmul_f64_f64_u8)(size_t, size_t, size_t, const double *, const double *, uint8_t *,
- double) = _matmul_f64_f64_u8;
-int (*matmul_f64_i8_f32)(size_t, size_t, size_t, const double *, const int8_t *, float *, double) = _matmul_f64_i8_f32;
-int (*matmul_f64_i8_f64)(size_t, size_t, size_t, const double *, const int8_t *, double *, double) = _matmul_f64_i8_f64;
-int (*matmul_f64_i8_i8)(size_t, size_t, size_t, const double *, const int8_t *, int8_t *, double) = _matmul_f64_i8_i8;
-int (*matmul_f64_i8_u8)(size_t, size_t, size_t, const double *, const int8_t *, uint8_t *, double) = _matmul_f64_i8_u8;
-int (*matmul_f64_u8_f32)(size_t, size_t, size_t, const double *, const uint8_t *, float *, double) = _matmul_f64_u8_f32;
-int (*matmul_f64_u8_f64)(size_t, size_t, size_t, const double *, const uint8_t *, double *,
- double) = _matmul_f64_u8_f64;
-int (*matmul_f64_u8_i8)(size_t, size_t, size_t, const double *, const uint8_t *, int8_t *, double) = _matmul_f64_u8_i8;
-int (*matmul_f64_u8_u8)(size_t, size_t, size_t, const double *, const uint8_t *, uint8_t *, double) = _matmul_f64_u8_u8;
-int (*matmul_i8_f32_f32)(size_t, size_t, size_t, const int8_t *, const float *, float *, double) = _matmul_i8_f32_f32;
-int (*matmul_i8_f32_f64)(size_t, size_t, size_t, const int8_t *, const float *, double *, double) = _matmul_i8_f32_f64;
-int (*matmul_i8_f32_i8)(size_t, size_t, size_t, const int8_t *, const float *, int8_t *, double) = _matmul_i8_f32_i8;
-int (*matmul_i8_f32_u8)(size_t, size_t, size_t, const int8_t *, const float *, uint8_t *, double) = _matmul_i8_f32_u8;
-int (*matmul_i8_f64_f32)(size_t, size_t, size_t, const int8_t *, const double *, float *, double) = _matmul_i8_f64_f32;
-int (*matmul_i8_f64_f64)(size_t, size_t, size_t, const int8_t *, const double *, double *, double) = _matmul_i8_f64_f64;
-int (*matmul_i8_f64_i8)(size_t, size_t, size_t, const int8_t *, const double *, int8_t *, double) = _matmul_i8_f64_i8;
-int (*matmul_i8_f64_u8)(size_t, size_t, size_t, const int8_t *, const double *, uint8_t *, double) = _matmul_i8_f64_u8;
-int (*matmul_i8_i8_f32)(size_t, size_t, size_t, const int8_t *, const int8_t *, float *, double) = _matmul_i8_i8_f32;
-int (*matmul_i8_i8_f64)(size_t, size_t, size_t, const int8_t *, const int8_t *, double *, double) = _matmul_i8_i8_f64;
-int (*matmul_i8_i8_i8)(size_t, size_t, size_t, const int8_t *, const int8_t *, int8_t *, double) = _matmul_i8_i8_i8;
-int (*matmul_i8_i8_u8)(size_t, size_t, size_t, const int8_t *, const int8_t *, uint8_t *, double) = _matmul_i8_i8_u8;
-int (*matmul_i8_u8_f32)(size_t, size_t, size_t, const int8_t *, const uint8_t *, float *, double) = _matmul_i8_u8_f32;
-int (*matmul_i8_u8_f64)(size_t, size_t, size_t, const int8_t *, const uint8_t *, double *, double) = _matmul_i8_u8_f64;
-int (*matmul_i8_u8_i8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, int8_t *, double) = _matmul_i8_u8_i8;
-int (*matmul_i8_u8_u8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, uint8_t *, double) = _matmul_i8_u8_u8;
-int (*matmul_u8_f32_f32)(size_t, size_t, size_t, const uint8_t *, const float *, float *, double) = _matmul_u8_f32_f32;
-int (*matmul_u8_f32_f64)(size_t, size_t, size_t, const uint8_t *, const float *, double *, double) = _matmul_u8_f32_f64;
-int (*matmul_u8_f32_i8)(size_t, size_t, size_t, const uint8_t *, const float *, int8_t *, double) = _matmul_u8_f32_i8;
-int (*matmul_u8_f32_u8)(size_t, size_t, size_t, const uint8_t *, const float *, uint8_t *, double) = _matmul_u8_f32_u8;
-int (*matmul_u8_f64_f32)(size_t, size_t, size_t, const uint8_t *, const double *, float *, double) = _matmul_u8_f64_f32;
-int (*matmul_u8_f64_f64)(size_t, size_t, size_t, const uint8_t *, const double *, double *,
- double) = _matmul_u8_f64_f64;
-int (*matmul_u8_f64_i8)(size_t, size_t, size_t, const uint8_t *, const double *, int8_t *, double) = _matmul_u8_f64_i8;
-int (*matmul_u8_f64_u8)(size_t, size_t, size_t, const uint8_t *, const double *, uint8_t *, double) = _matmul_u8_f64_u8;
-int (*matmul_u8_i8_f32)(size_t, size_t, size_t, const uint8_t *, const int8_t *, float *, double) = _matmul_u8_i8_f32;
-int (*matmul_u8_i8_f64)(size_t, size_t, size_t, const uint8_t *, const int8_t *, double *, double) = _matmul_u8_i8_f64;
-int (*matmul_u8_i8_i8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, int8_t *, double) = _matmul_u8_i8_i8;
-int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = _matmul_u8_i8_u8;
-int (*matmul_u8_u8_f32)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, float *, double) = _matmul_u8_u8_f32;
-int (*matmul_u8_u8_f64)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, double *, double) = _matmul_u8_u8_f64;
-int (*matmul_u8_u8_i8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, int8_t *, double) = _matmul_u8_u8_i8;
-int (*matmul_u8_u8_u8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, double) = _matmul_u8_u8_u8;
+int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = _matmul_u8_i8_u8;
diff --git a/test/benchmark.c b/test/benchmark.c
@@ -1,148 +1,99 @@
-#include "nemequ/munit.h"
-#include "../src/matmul.h"
+#define _POSIX_C_SOURCE 199309L
+#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
+#include <string.h>
#include <time.h>
-static float *alloc_f32(size_t size) {
- return (float*)calloc(size, sizeof(float));
-}
+#include "../src/matmul.h"
-static double *alloc_f64(size_t size) {
- return (double*)calloc(size, sizeof(double));
-}
+#define RUNS 100
-static int8_t *alloc_i8(size_t size) {
- return (int8_t*)calloc(size, sizeof(int8_t));
+// Returns the current CLOCK_MONOTONIC reading.
+// The struct is zero-initialized so that a failed clock_gettime() call hands back a well-defined
+// all-zero timestamp instead of indeterminate stack contents.
+static struct timespec timespec_now(void) {
+  struct timespec ts = {0};
+  clock_gettime(CLOCK_MONOTONIC, &ts);
+  return ts;
 }
-static uint8_t *alloc_u8(size_t size) {
- return (uint8_t*)calloc(size, sizeof(uint8_t));
+// Elapsed time from `start` to `end`, expressed in milliseconds.
+static double timespec_diff_ms(struct timespec start, struct timespec end) {
+  const double whole_ms = (end.tv_sec - start.tv_sec) * 1000.0;    // seconds part
+  const double frac_ms = (end.tv_nsec - start.tv_nsec) / 1e6;      // nanoseconds part
+  return whole_ms + frac_ms;
 }
-static double get_time_ms(void) {
- struct timespec ts;
- clock_gettime(CLOCK_MONOTONIC, &ts);
- return ts.tv_sec * 1000.0 + ts.tv_nsec / 1000000.0;
+// qsort() comparator for doubles: negative when *a < *b, positive when *a > *b, otherwise 0.
+static int compare_double(const void *a, const void *b) {
+  const double *lhs = a;
+  const double *rhs = b;
+  if (*lhs < *rhs) return -1;
+  if (*lhs > *rhs) return 1;
+  return 0;
 }
-static MunitResult test_bench_f32_128x128(const MunitParameter *params, void *data) {
- (void)params; (void)data;
-
- size_t m = 128, n = 128, p = 128;
- size_t size = m * n;
- float *A = alloc_f32(size);
- float *B = alloc_f32(size);
- float *C = alloc_f32(m * p);
-
- for (size_t i = 0; i < size; i++) {
- A[i] = (float)(i % 100) * 0.1f;
- B[i] = (float)((i + 7) % 100) * 0.1f;
- }
-
- double start = get_time_ms();
- for (int iter = 0; iter < 10; iter++) {
- matmul(m, n, p, A, B, C);
- }
- double elapsed = get_time_ms() - start;
-
- matmul_feature_t feat = matmul_get_feature();
- printf("f32 128x128: %.2f ms (feature: %s)\n", elapsed / 10.0, matmul_get_feature_name(feat));
-
- free(A); free(B); free(C);
- return MUNIT_OK;
+// Linearly interpolated percentile of a sample array.
+//   sorted: n samples; the caller is expected to have sorted them in ascending order
+//   p:      requested percentile; values outside [0, 100] are clamped to that range
+//   n:      number of samples
+// Returns 0.0 when n <= 0, because there is no element to read.
+static double percentile(double *sorted, double p, int n) {
+  if (n <= 0) return 0.0;
+  if (p < 0.0) p = 0.0;
+  if (p > 100.0) p = 100.0;
+
+  // Fractional rank into the array; its integer part and the next index bracket the result.
+  double idx = p / 100.0 * (n - 1);
+  int lo = (int)idx;
+  int hi = lo + 1;
+  double frac = idx - lo;
+  if (hi >= n) hi = n - 1;  // the last element has no upper neighbour
+  return sorted[lo] * (1.0 - frac) + sorted[hi] * frac;
 }
-static MunitResult test_bench_f64_128x128(const MunitParameter *params, void *data) {
- (void)params; (void)data;
-
- size_t m = 128, n = 128, p = 128;
- size_t size = m * n;
- double *A = alloc_f64(size);
- double *B = alloc_f64(size);
- double *C = alloc_f64(m * p);
-
- for (size_t i = 0; i < size; i++) {
- A[i] = (double)(i % 100) * 0.1;
- B[i] = (double)((i + 7) % 100) * 0.1;
- }
-
- double start = get_time_ms();
- for (int iter = 0; iter < 10; iter++) {
- matmul_f64(m, n, p, A, B, C);
- }
- double elapsed = get_time_ms() - start;
-
- printf("f64 128x128: %.2f ms\n", elapsed / 10.0);
-
- free(A); free(B); free(C);
- return MUNIT_OK;
-}
+// Benchmarks matmul_u8_i8_u8 for one problem size and prints a single table row.
+//   m, n, p: A holds m*n bytes, B holds n*p bytes, C holds m*p bytes
+//   runs:    requested number of timed runs; forced to 3 when m >= 4096 and then clamped to
+//            [1, RUNS] so the fixed-size sample buffer below can never be overrun
+// The row lists the 1st/5th/50th/95th/99th percentile latency in ms, followed by a throughput
+// figure of 2*m*n*p operations per call, in units of 1e9 per second, based on the median.
+// NOTE(review): the int returned by matmul_u8_i8_u8 is not inspected here.
+static void bench(size_t m, size_t n, size_t p, int runs) {
+  uint8_t *A = malloc(m * n);
+  int8_t *B = malloc(n * p);
+  uint8_t *C = malloc(m * p);
+  uint8_t *Cwarm = malloc(m * p);
+  double times[RUNS];
-static MunitResult test_bench_i8_128x128(const MunitParameter *params, void *data) {
- (void)params; (void)data;
-
- size_t m = 128, n = 128, p = 128;
- size_t size = m * n;
- int8_t *A = alloc_i8(size);
- int8_t *B = alloc_i8(size);
- int8_t *C = alloc_i8(m * p);
-
- for (size_t i = 0; i < size; i++) {
- A[i] = (int8_t)(i % 256 - 128);
- B[i] = (int8_t)((i + 7) % 256 - 128);
- }
-
- double start = get_time_ms();
- for (int iter = 0; iter < 10; iter++) {
- matmul_i8(m, n, p, A, B, C);
- }
- double elapsed = get_time_ms() - start;
-
- printf("i8 128x128: %.2f ms\n", elapsed / 10.0);
-
- free(A); free(B); free(C);
- return MUNIT_OK;
-}
+
+  if (!A || !B || !C || !Cwarm) {
+    fprintf(stderr, "OOM for %zu x %zu\n", m, n);
+    free(A);
+    free(B);
+    free(C);
+    free(Cwarm);
+    return;
+  }
+
+  for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)(rand() % 256);
+  for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(rand() % 256);
+  memset(C, 0, m * p);
+  memset(Cwarm, 0, m * p);
+
+  // One untimed call into a separate output buffer before measurements start.
+  matmul_u8_i8_u8(m, n, p, A, B, Cwarm, 0.0);
+
+  int actual_runs = runs;
+  if (m >= 4096) actual_runs = 3;
+  // times[] has RUNS slots and percentile() needs at least one sample.
+  if (actual_runs > RUNS) actual_runs = RUNS;
+  if (actual_runs < 1) actual_runs = 1;
-static MunitResult test_bench_u8_128x128(const MunitParameter *params, void *data) {
- (void)params; (void)data;
-
- size_t m = 128, n = 128, p = 128;
- size_t size = m * n;
- uint8_t *A = alloc_u8(size);
- uint8_t *B = alloc_u8(size);
- uint8_t *C = alloc_u8(m * p);
-
- for (size_t i = 0; i < size; i++) {
- A[i] = (uint8_t)(i % 256);
- B[i] = (uint8_t)((i + 7) % 256);
- }
-
- double start = get_time_ms();
- for (int iter = 0; iter < 10; iter++) {
- matmul_u8(m, n, p, A, B, C);
- }
- double elapsed = get_time_ms() - start;
-
- printf("u8 128x128: %.2f ms\n", elapsed / 10.0);
-
- free(A); free(B); free(C);
- return MUNIT_OK;
+
+  for (int r = 0; r < actual_runs; r++) {
+    memset(C, 0, m * p);  // reset the output outside the timed region
+    struct timespec start = timespec_now();
+    matmul_u8_i8_u8(m, n, p, A, B, C, 0.0);
+    struct timespec end = timespec_now();
+    times[r] = timespec_diff_ms(start, end);
+  }
+
+  qsort(times, (size_t)actual_runs, sizeof times[0], compare_double);
+
+  // ops / (ms * 1e6) == ops / (s * 1e9)
+  double median_ms = percentile(times, 50, actual_runs);
+  double gflops = 2.0 * m * n * p / (median_ms * 1e6);
+
+  printf("%8zu x %8zu | %6.2f | %6.2f | %6.2f | %6.2f | %6.2f | %8.1f\n", m, n, percentile(times, 1, actual_runs),
+         percentile(times, 5, actual_runs), median_ms, percentile(times, 95, actual_runs),
+         percentile(times, 99, actual_runs), gflops);
+
+  // Release the buffers; the original success path returned without freeing them.
+  free(A);
+  free(B);
+  free(C);
+  free(Cwarm);
 }
-static MunitTest tests[] = {
- {"/f32-128x128", test_bench_f32_128x128, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/f64-128x128", test_bench_f64_128x128, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/i8-128x128", test_bench_i8_128x128, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/u8-128x128", test_bench_u8_128x128, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}
-};
-
-static const MunitSuite suite = {
- {"/benchmark", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE},
-};
-
-int main(int argc, char *argv[MUNIT_ARRAY_PARAM(argc)]) {
- return munit_suite_main(&suite, NULL, argc, argv);
-}
-\ No newline at end of file
+// Entry point: seeds rand() with a fixed value, prints the table header, benchmarks each square
+// problem size in turn with RUNS requested runs, and closes the table.
+int main(void) {
+  // The 16384 case from the original list stays disabled.
+  static const size_t sizes[] = {16, 64, 256, 1024, 4096};
+  static const char rule[] = "--------------------------------------------------------------\n";
+
+  srand(42);
+
+  printf("Benchmark: u8_i8_u8 matmul, %d runs per size\n", RUNS);
+  printf("%s", rule);
+  printf("%8s | %8s | %8s | %8s | %8s | %8s | %8s\n", "M x N", "1% (ms)", "5% (ms)", "50% (ms)", "95% (ms)", "99% (ms)",
+         "GFLOPS");
+  printf("%s", rule);
+
+  for (size_t i = 0; i < sizeof sizes / sizeof sizes[0]; i++) {
+    bench(sizes[i], sizes[i], sizes[i], RUNS);
+  }
+
+  printf("%s", rule);
+
+  return 0;
+}
diff --git a/test/test_matmul.c b/test/test_matmul.c
@@ -1,324 +1,146 @@
-#include "nemequ/munit.h"
-#include "../src/matmul.h"
+#include <math.h>
+#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
-#include <math.h>
-#include <stdint.h>
-#define f32 float
-#define f64 double
-#define i8 int8_t
-#define u8 uint8_t
+#include "../src/matmul.h"
+#include "nemequ/munit.h"
+#include "test_matmul_simd.h"
-static void ref_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += A[i*n+k] * B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += A[i*n+k] * B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
-static void ref_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
-}
-static void ref_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
-}
-static void ref_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
-}
+// Scalar reference: C (m x p) = A (m x n) * B (n x p), all row-major. Each dot product is
+// accumulated in an int and then saturated to the uint8_t range [0, 255].
 static void ref_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
+  for (size_t row = 0; row < m; row++) {
+    for (size_t col = 0; col < p; col++) {
+      int acc = 0;
+      for (size_t k = 0; k < n; k++) {
+        acc += (int)A[row * n + k] * (int)B[k * p + col];
+      }
+      C[row * p + col] = (uint8_t)(acc > 255 ? 255 : (acc < 0 ? 0 : acc));
+    }
+  }
 }
-static void ref_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; }
+
+// Runs matmul_fn on a fixed 2x3 by 3x2 product and compares every output element with the
+// scalar reference ref_u8_i8_u8.
+//   name:      label of the implementation under test (not used by the check itself)
+//   matmul_fn: kernel to verify; invoked with 0.0 as its final argument
+//   epsilon:   allowed absolute difference per element, truncated to an int
+// Returns MUNIT_OK when all four elements are within tolerance, MUNIT_FAIL otherwise.
+// NOTE(review): the int returned by matmul_fn is not inspected.
+static MunitResult test_u8_i8_u8_small(const char *name,
+                                       int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *,
+                                                        uint8_t *, double),
+                                       double epsilon) {
+  (void)name;
+  uint8_t A[] = {1, 2, 3, 4, 5, 6};
+  int8_t B[] = {1, 0, 0, 1, 0, 0};
+  // Zero both buffers, as the medium-size test does, so a kernel that leaves part of C
+  // unwritten is compared against defined values rather than indeterminate stack contents.
+  uint8_t C[4] = {0};
+  uint8_t E[4] = {0};
+
+  ref_u8_i8_u8(2, 3, 2, A, B, E);
+  matmul_fn(2, 3, 2, A, B, C, 0.0);
+
+  for (int i = 0; i < 4; i++) {
+    int d = (int)E[i] - (int)C[i];
+    if (d < 0) d = -d;
+    if (d > (int)epsilon) return MUNIT_FAIL;
+  }
+  return MUNIT_OK;
+}
+
+// 64x64x64 comparison of matmul_fn against the scalar reference ref_u8_i8_u8 on inputs
+// generated from fixed arithmetic formulas. Returns MUNIT_SKIP when an allocation fails,
+// MUNIT_FAIL at the first element whose absolute difference exceeds (int)epsilon, and
+// MUNIT_OK otherwise. All buffers are released on a single exit path.
+static MunitResult test_u8_i8_u8_medium(const char *name,
+                                        int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *,
+                                                         uint8_t *, double),
+                                        double epsilon) {
+  const size_t m = 64, n = 64, p = 64;
+  MunitResult verdict = MUNIT_SKIP;
+
+  uint8_t *A = malloc(m * n);
+  int8_t *B = malloc(n * p);
+  uint8_t *C = malloc(m * p);
+  uint8_t *E = malloc(m * p);
+
+  if (A && B && C && E) {
+    // Deterministic pseudo-random values
+    for (size_t idx = 0; idx < m * n; idx++) A[idx] = (uint8_t)((idx * 7 + 13) % 251);
+    for (size_t idx = 0; idx < n * p; idx++) B[idx] = (int8_t)(((idx * 11 + 17) % 211) - 105);
+    memset(C, 0, m * p);
+    memset(E, 0, m * p);
+
+    ref_u8_i8_u8(m, n, p, A, B, E);
+    matmul_fn(m, n, p, A, B, C, 0.0);
+
+    verdict = MUNIT_OK;
+    const int tolerance = (int)epsilon;
+    for (size_t idx = 0; idx < m * p; idx++) {
+      const int delta = (int)E[idx] - (int)C[idx];
+      const int gap = delta < 0 ? -delta : delta;
+      if (gap > tolerance) {
+        verdict = MUNIT_FAIL;
+        break;
+      }
+    }
+  }
+
+  free(A);
+  free(B);
+  free(C);
+  free(E);
+  return verdict;
 }
-static void ref_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; }
+
+// Default check for one implementation: forwards all three arguments unchanged to
+// test_u8_i8_u8_small and returns its result.
+static MunitResult test_u8_i8_u8(const char *name,
+ int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *,
+ double),
+ double epsilon) {
+ return test_u8_i8_u8_small(name, matmul_fn, epsilon);
 }
-static void ref_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); }
+
+// munit test: checks matmul_scalar_u8_i8_u8 via test_u8_i8_u8 with a tolerance of 0.
+// The munit params/data arguments are unused.
+static MunitResult test_scalar_u8_i8_u8(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8("scalar", matmul_scalar_u8_i8_u8, 0);
 }
-static void ref_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C) {
- for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); }
+
+// munit test: checks matmul_scalar_u8_i8_u8 on the 64x64x64 case via test_u8_i8_u8_medium
+// with a tolerance of 0. The munit params/data arguments are unused.
+static MunitResult test_scalar_u8_i8_u8_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8_medium("scalar", matmul_scalar_u8_i8_u8, 0);
 }
-#define TEST_FLT(ta, tb, tc, backend) \
-static MunitResult test_##backend##_##ta##_##tb##_##tc(const MunitParameter *params, void *data) { \
- (void)params; (void)data; \
- ta A[] = {1, 2, 3, 4, 5, 6}; tb B[] = {1, 0, 0, 1, 0, 0}; tc C[4], E[4]; \
- ref_##ta##_##tb##_##tc(2, 3, 2, A, B, E); \
- matmul_##backend##_##ta##_##tb##_##tc(2, 3, 2, A, B, C, 0.0); \
- for (int i = 0; i < 4; i++) { float d = E[i] - C[i]; if (d < 0) d = -d; if (d > 1e-5f) return MUNIT_FAIL; } \
- return MUNIT_OK; \
+#ifdef __AVX512VNNI__
+// munit test: checks matmul_avx512vnni_u8_i8_u8 via test_u8_i8_u8 with a tolerance of 0.
+// Only compiled when __AVX512VNNI__ is defined (see the enclosing #ifdef).
+static MunitResult test_avx512vnni_u8_i8_u8(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0);
 }
-#define TEST_DBL(ta, tb, tc, backend) \
-static MunitResult test_##backend##_##ta##_##tb##_##tc(const MunitParameter *params, void *data) { \
- (void)params; (void)data; \
- ta A[] = {1, 2, 3, 4, 5, 6}; tb B[] = {1, 0, 0, 1, 0, 0}; tc C[4], E[4]; \
- ref_##ta##_##tb##_##tc(2, 3, 2, A, B, E); \
- matmul_##backend##_##ta##_##tb##_##tc(2, 3, 2, A, B, C, 0.0); \
- for (int i = 0; i < 4; i++) { double d = E[i] - C[i]; if (d < 0) d = -d; if (d > 1e-10) return MUNIT_FAIL; } \
- return MUNIT_OK; \
+
+// munit test: checks matmul_avx512vnni_u8_i8_u8 on the 64x64x64 case via
+// test_u8_i8_u8_medium with a tolerance of 0. Only compiled when __AVX512VNNI__ is defined.
+static MunitResult test_avx512vnni_u8_i8_u8_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8_medium("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0);
 }
-#define TEST_INT(ta, tb, tc, backend) \
-static MunitResult test_##backend##_##ta##_##tb##_##tc(const MunitParameter *params, void *data) { \
- (void)params; (void)data; \
- ta A[] = {1, 2, 3, 4, 5, 6}; tb B[] = {1, 0, 0, 1, 0, 0}; tc C[4], E[4]; \
- ref_##ta##_##tb##_##tc(2, 3, 2, A, B, E); \
- matmul_##backend##_##ta##_##tb##_##tc(2, 3, 2, A, B, C, 0.0); \
- for (int i = 0; i < 4; i++) if (E[i] != C[i]) return MUNIT_FAIL; \
- return MUNIT_OK; \
+#endif
+
+// munit test: checks whatever implementation the matmul_u8_i8_u8 function pointer currently
+// targets, via test_u8_i8_u8 with a tolerance of 0. The munit params/data arguments are unused.
+static MunitResult test_dispatched_u8_i8_u8(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8("dispatched", matmul_u8_i8_u8, 0);
 }
-TEST_FLT(f32, f32, f32, scalar) TEST_DBL(f32, f32, f64, scalar) TEST_INT(f32, f32, i8, scalar) TEST_INT(f32, f32, u8, scalar)
-TEST_FLT(f32, f64, f32, scalar) TEST_DBL(f32, f64, f64, scalar) TEST_INT(f32, f64, i8, scalar) TEST_INT(f32, f64, u8, scalar)
-TEST_FLT(f32, i8, f32, scalar) TEST_DBL(f32, i8, f64, scalar) TEST_INT(f32, i8, i8, scalar) TEST_INT(f32, i8, u8, scalar)
-TEST_FLT(f32, u8, f32, scalar) TEST_DBL(f32, u8, f64, scalar) TEST_INT(f32, u8, i8, scalar) TEST_INT(f32, u8, u8, scalar)
-TEST_FLT(f64, f32, f32, scalar) TEST_DBL(f64, f32, f64, scalar) TEST_INT(f64, f32, i8, scalar) TEST_INT(f64, f32, u8, scalar)
-TEST_FLT(f64, f64, f32, scalar) TEST_DBL(f64, f64, f64, scalar) TEST_INT(f64, f64, i8, scalar) TEST_INT(f64, f64, u8, scalar)
-TEST_FLT(f64, i8, f32, scalar) TEST_DBL(f64, i8, f64, scalar) TEST_INT(f64, i8, i8, scalar) TEST_INT(f64, i8, u8, scalar)
-TEST_FLT(f64, u8, f32, scalar) TEST_DBL(f64, u8, f64, scalar) TEST_INT(f64, u8, i8, scalar) TEST_INT(f64, u8, u8, scalar)
-TEST_FLT(i8, f32, f32, scalar) TEST_DBL(i8, f32, f64, scalar) TEST_INT(i8, f32, i8, scalar) TEST_INT(i8, f32, u8, scalar)
-TEST_FLT(i8, f64, f32, scalar) TEST_DBL(i8, f64, f64, scalar) TEST_INT(i8, f64, i8, scalar) TEST_INT(i8, f64, u8, scalar)
-TEST_FLT(i8, i8, f32, scalar) TEST_DBL(i8, i8, f64, scalar) TEST_INT(i8, i8, i8, scalar) TEST_INT(i8, i8, u8, scalar)
-TEST_FLT(i8, u8, f32, scalar) TEST_DBL(i8, u8, f64, scalar) TEST_INT(i8, u8, i8, scalar) TEST_INT(i8, u8, u8, scalar)
-TEST_FLT(u8, f32, f32, scalar) TEST_DBL(u8, f32, f64, scalar) TEST_INT(u8, f32, i8, scalar) TEST_INT(u8, f32, u8, scalar)
-TEST_FLT(u8, f64, f32, scalar) TEST_DBL(u8, f64, f64, scalar) TEST_INT(u8, f64, i8, scalar) TEST_INT(u8, f64, u8, scalar)
-TEST_FLT(u8, i8, f32, scalar) TEST_DBL(u8, i8, f64, scalar) TEST_INT(u8, i8, i8, scalar) TEST_INT(u8, i8, u8, scalar)
-TEST_FLT(u8, u8, f32, scalar) TEST_DBL(u8, u8, f64, scalar) TEST_INT(u8, u8, i8, scalar) TEST_INT(u8, u8, u8, scalar)
+// munit test: checks whatever implementation the matmul_u8_i8_u8 function pointer currently
+// targets on the 64x64x64 case, via test_u8_i8_u8_medium with a tolerance of 0.
+static MunitResult test_dispatched_u8_i8_u8_medium(const MunitParameter *params, void *data) {
+ (void)params;
+ (void)data;
+ return test_u8_i8_u8_medium("dispatched", matmul_u8_i8_u8, 0);
+}
static MunitTest tests[] = { /*
+ * Test registry handed to the suite below. Each entry pairs a test path with its
+ * test function; the remaining fields are left at NULL / MUNIT_TEST_OPTION_NONE.
+ * This commit drops every scalar type combination except u8-i8-u8 and adds the
+ * "medium", AVX512-VNNI and dispatched variants.
+ */
- {"/scalar-f32-f32-f32", test_scalar_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-f32-f64", test_scalar_f32_f32_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-f32-i8", test_scalar_f32_f32_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-f32-u8", test_scalar_f32_f32_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-f64-f32", test_scalar_f32_f64_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-f64-f64", test_scalar_f32_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-f64-i8", test_scalar_f32_f64_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-f64-u8", test_scalar_f32_f64_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-i8-f32", test_scalar_f32_i8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-i8-f64", test_scalar_f32_i8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-i8-i8", test_scalar_f32_i8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-i8-u8", test_scalar_f32_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-u8-f32", test_scalar_f32_u8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-u8-f64", test_scalar_f32_u8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-u8-i8", test_scalar_f32_u8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f32-u8-u8", test_scalar_f32_u8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-f32-f32", test_scalar_f64_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-f32-f64", test_scalar_f64_f32_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-f32-i8", test_scalar_f64_f32_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-f32-u8", test_scalar_f64_f32_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-f64-f32", test_scalar_f64_f64_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-f64-f64", test_scalar_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-f64-i8", test_scalar_f64_f64_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-f64-u8", test_scalar_f64_f64_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-i8-f32", test_scalar_f64_i8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-i8-f64", test_scalar_f64_i8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-i8-i8", test_scalar_f64_i8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-i8-u8", test_scalar_f64_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-u8-f32", test_scalar_f64_u8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-u8-f64", test_scalar_f64_u8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-u8-i8", test_scalar_f64_u8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-f64-u8-u8", test_scalar_f64_u8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-f32-f32", test_scalar_i8_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-f32-f64", test_scalar_i8_f32_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-f32-i8", test_scalar_i8_f32_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-f32-u8", test_scalar_i8_f32_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-f64-f32", test_scalar_i8_f64_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-f64-f64", test_scalar_i8_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-f64-i8", test_scalar_i8_f64_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-f64-u8", test_scalar_i8_f64_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-i8-f32", test_scalar_i8_i8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-i8-f64", test_scalar_i8_i8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-i8-i8", test_scalar_i8_i8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-i8-u8", test_scalar_i8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-u8-f32", test_scalar_i8_u8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-u8-f64", test_scalar_i8_u8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-u8-i8", test_scalar_i8_u8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-i8-u8-u8", test_scalar_i8_u8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-f32-f32", test_scalar_u8_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-f32-f64", test_scalar_u8_f32_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-f32-i8", test_scalar_u8_f32_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-f32-u8", test_scalar_u8_f32_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-f64-f32", test_scalar_u8_f64_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-f64-f64", test_scalar_u8_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-f64-i8", test_scalar_u8_f64_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-f64-u8", test_scalar_u8_f64_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-i8-f32", test_scalar_u8_i8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-i8-f64", test_scalar_u8_i8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-i8-i8", test_scalar_u8_i8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
 {"/scalar-u8-i8-u8", test_scalar_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-u8-f32", test_scalar_u8_u8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-u8-f64", test_scalar_u8_u8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-u8-i8", test_scalar_u8_u8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {"/scalar-u8-u8-u8", test_scalar_u8_u8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
- {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}
-};
+ {"/scalar-u8-i8-u8-medium", test_scalar_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+#ifdef __AVX512VNNI__ /* these two entries are registered only when the compiler defines __AVX512VNNI__ */
+ {"/avx512vnni-u8-i8-u8", test_avx512vnni_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/avx512vnni-u8-i8-u8-medium", test_avx512vnni_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+#endif
+ {"/dispatched-u8-i8-u8", test_dispatched_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {"/dispatched-u8-i8-u8-medium", test_dispatched_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
+ {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}}; /* all-NULL entry: munit's end-of-list marker */
-static const MunitSuite suite = { "/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE };
+static const MunitSuite suite = {"/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE}; /*
+ * Top-level suite: prefix "/matmul" over the tests[] registry. Per munit's MunitSuite
+ * field order the remaining values are: no child suites (NULL), 1 iteration, and
+ * default suite options. Only brace spacing changes in this commit.
+ */
int main(int argc, char *argv[MUNIT_ARRAY_PARAM(argc)]) { /*
+ * Program entry: forwards the command line to munit_suite_main together with the
+ * suite above and a NULL second argument (munit's user-data pointer), and returns
+ * munit's result as the process exit status.
+ */
- return munit_suite_main(&suite, NULL, argc, argv);
+ return munit_suite_main(&suite, NULL, argc, argv);
}
diff --git a/test/test_matmul_simd.h b/test/test_matmul_simd.h
@@ -41,6 +41,10 @@
#include <stddef.h>
#include <stdint.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
+
#define MATMUL_SIMD_VERSION "1.0.0"
#define MATMUL_FLAG_SCALAR (1 << 0)
@@ -54,128 +58,16 @@ typedef uint32_t matmul_feature_t;
matmul_feature_t matmul_get_feature(void); /*
+ * Returns a matmul_feature_t (a uint32_t per the typedef in this header).
+ * NOTE(review): presumably one or more MATMUL_FLAG_* bits describing the code path
+ * selected for this CPU - confirm against the definition in src/matmul.c.
+ */
const char *matmul_get_feature_name(matmul_feature_t feat); /*
+ * Returns a C string for the given feature value.
+ * NOTE(review): presumably a human-readable name backed by static storage; neither
+ * the lifetime nor the result for unknown values is visible here - confirm.
+ */
-int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale);
-int matmul_scalar_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale);
-int matmul_scalar_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale);
-int matmul_scalar_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale);
-int matmul_scalar_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale);
-int matmul_scalar_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale);
-int matmul_scalar_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale);
-int matmul_scalar_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale);
-int matmul_scalar_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale);
-int matmul_scalar_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale);
-int matmul_scalar_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale);
-int matmul_scalar_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale);
-int matmul_scalar_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale);
-int matmul_scalar_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale);
-int matmul_scalar_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_scalar_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale);
-int matmul_scalar_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale);
-int matmul_scalar_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale);
-int matmul_scalar_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale);
-int matmul_scalar_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale);
-int matmul_scalar_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale);
-int matmul_scalar_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale);
-int matmul_scalar_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale);
-int matmul_scalar_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C, double scale);
-int matmul_scalar_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale);
-int matmul_scalar_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale);
-int matmul_scalar_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale);
-int matmul_scalar_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale);
-int matmul_scalar_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale);
-int matmul_scalar_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C, double scale);
-int matmul_scalar_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_scalar_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C, double scale);
-int matmul_scalar_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale);
-int matmul_scalar_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C, double scale);
-int matmul_scalar_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale);
-int matmul_scalar_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale);
-int matmul_scalar_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale);
-int matmul_scalar_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale);
-int matmul_scalar_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale);
-int matmul_scalar_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale);
-int matmul_scalar_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale);
-int matmul_scalar_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale);
-int matmul_scalar_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale);
-int matmul_scalar_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale);
-int matmul_scalar_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale);
-int matmul_scalar_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale);
-int matmul_scalar_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_scalar_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale);
-int matmul_scalar_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale);
-int matmul_scalar_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale);
-int matmul_scalar_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale);
-int matmul_scalar_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale);
-int matmul_scalar_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale);
-int matmul_scalar_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C, double scale);
-int matmul_scalar_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale);
-int matmul_scalar_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C, double scale);
-int matmul_scalar_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale);
-int matmul_scalar_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale);
-int matmul_scalar_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale);
int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale); /*
+ * The only matmul_scalar_* prototype this commit keeps: A holds uint8_t values, B
+ * int8_t values, and the result is written through C as uint8_t. A and B are
+ * const-qualified, so only C is written through.
+ * NOTE(review): m, n and p are presumably the matrix dimensions (A m-by-n, B n-by-p,
+ * C m-by-p), scale presumably a divisor applied before storing, and the int return
+ * presumably a status code - none of this is visible in the header; confirm in
+ * src/matmul.c.
+ */
-int matmul_scalar_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale);
-int matmul_scalar_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale);
-int matmul_scalar_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_scalar_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale);
-
-#ifdef __cplusplus
-extern "C" {
-#endif
#ifdef __AVX2__
#include <immintrin.h>
-int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale);
-int matmul_avx2_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale);
-int matmul_avx2_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale);
-int matmul_avx2_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale);
-int matmul_avx2_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale);
-int matmul_avx2_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale);
-int matmul_avx2_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale);
-int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale);
-int matmul_avx2_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale);
-int matmul_avx2_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale);
-int matmul_avx2_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale);
-int matmul_avx2_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale);
-int matmul_avx2_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale);
-int matmul_avx2_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale);
-int matmul_avx2_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_avx2_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale);
-int matmul_avx2_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale);
-int matmul_avx2_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale);
-int matmul_avx2_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale);
int matmul_avx2_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale); /*
+ * The only matmul_avx2_* prototype this commit keeps; it sits inside the
+ * #ifdef __AVX2__ block, so it is declared only when the compiler defines __AVX2__.
+ * Same parameter list and return type as matmul_scalar_u8_i8_u8.
+ * NOTE(review): presumably computes the same result as the scalar variant using
+ * AVX2 intrinsics - confirm in src/matmul.c.
+ */
-int matmul_avx2_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale);
-int matmul_avx2_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale);
-int matmul_avx2_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_avx2_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale);
#endif
-#ifdef __AVX512F__
-int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale);
-int matmul_avx512_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale);
-int matmul_avx512_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale);
-int matmul_avx512_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale);
-int matmul_avx512_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale);
-int matmul_avx512_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale);
-int matmul_avx512_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale);
-int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale);
-int matmul_avx512_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale);
-int matmul_avx512_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale);
-int matmul_avx512_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale);
-int matmul_avx512_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale);
-int matmul_avx512_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale);
-int matmul_avx512_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale);
-int matmul_avx512_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_avx512_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale);
-int matmul_avx512_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale);
-int matmul_avx512_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale);
-int matmul_avx512_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale);
-int matmul_avx512_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale);
-int matmul_avx512_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale);
-int matmul_avx512_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale);
-int matmul_avx512_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_avx512_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale);
+#ifdef __AVX512VNNI__ /* replaces the former __AVX512F__ block; declared only when the compiler defines __AVX512VNNI__ */
+int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C,
+ double scale); /*
+ * Same parameter list and return type as matmul_scalar_u8_i8_u8: uint8_t A, int8_t B,
+ * uint8_t output through C.
+ * NOTE(review): presumably the AVX512-VNNI implementation of the same kernel; the
+ * meaning of m, n, p, scale and the int return is not visible here - confirm in
+ * src/matmul.c.
+ */
#endif
#ifdef __cplusplus