matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

commit 4615e6a5d18b3ee05831fab834628d51c011a6d9
parent 11f476190e1bacace411948da4e194c3f6080a93
Author: finwo <finwo@pm.me>
Date:   Sat, 18 Apr 2026 18:45:29 +0200

Functional

Diffstat:
M.gitignore | 1+
MREADME.md | 89++++++++++++++++++++++++++++++++++++++++++++++---------------------------------
Msrc/matmul.c | 3468++++---------------------------------------------------------------------------
Mtest/benchmark.c | 212++++++++++++++++++++++++++++++-------------------------------------------------
Mtest/test_matmul.c | 424+++++++++++++++++++++++--------------------------------------------------------
Mtest/test_matmul_simd.h | 122+++++--------------------------------------------------------------------------
6 files changed, 417 insertions(+), 3899 deletions(-)

diff --git a/.gitignore b/.gitignore @@ -2,3 +2,4 @@ *.o /test_matmul /.dep +/benchmark diff --git a/README.md b/README.md @@ -1,6 +1,9 @@ -# Matmul - High-Performance Matrix Multiplication Library +# Matmul - Accelerated Matrix Multiplication Library -A lightweight, type-safe matrix multiplication library with compile-time dispatch, scaling support, and infrastructure for SIMD acceleration. +A lightweight, type-safe matrix multiplication library with runtime dispatch and SIMD acceleration. + +**Current Implementation:** `uint8_t` × `int8_t` → `uint8_t` +*(additional type combinations planned)* ## Installation @@ -14,12 +17,11 @@ Alternatively, you can include the [matmul.c](src/matmul.c) and [matmul.h](src/m ## Features -- **64 Type Combinations**: Supports all combinations of f32/f64/i8/u8 for input matrices A, B and output C -- **Compile-Time Dispatch**: Uses C11 `_Generic` for zero-overhead type detection -- **Scale Parameter**: Divide output before writing (e.g., scale=64 for quantization emulation) -- **Future SIMD Ready**: Infrastructure in place for AVX2/AVX512/VNNI acceleration +- **SIMD Acceleration**: AVX2, AVX512, AVX-VNNI, and AVX512-VNNI with automatic CPU feature detection +- **Multi-Core Parallelization**: OpenMP-based parallel processing across matrix rows +- **Tiled Algorithm**: Cache-blocking for improved locality at large matrix sizes - **No Dependencies**: Pure C11 implementation with no external libraries -- **Well Tested**: 64 unit tests covering all type combinations +- **Well Tested**: Unit tests verify correctness across all implementations ## Quick Start @@ -28,25 +30,25 @@ Alternatively, you can include the [matmul.c](src/matmul.c) and [matmul.h](src/m #include <stdio.h> int main() { - // 2x3 matrix A - float A[6] = {1, 2, 3, 4, 5, 6}; - // 3x2 matrix B - float B[6] = {1, 0, 0, 1, 0, 0}; - // 2x2 result matrix C - float C[4]; + // 2x3 matrix A (uint8_t) + uint8_t A[6] = {1, 2, 3, 4, 5, 6}; + // 3x2 matrix B (int8_t) + int8_t B[6] = {1, 0, 0, 1, 0, 0}; + // 2x2 result matrix C (uint8_t) + uint8_t C[4]; // Multiply A(2x3) * B(3x2) = C(2x2) matmul(2, 3, 2, A, B, C, 0.0); // scale=0: no scaling // C should be [1, 2, 4, 5] - printf("C[0] = %f, C[1] = %f, C[2] = %f, C[3] = %f\n", + printf("C[0] = %u, C[1] = %u, C[2] = %u, C[3] = %u\n", C[0], C[1], C[2], C[3]); return 0; } ``` -Compile with: `cc -o example example.c -lm` +Compile with: `cc -o example example.c -lm -fopenmp` ## API Reference @@ -58,19 +60,17 @@ matmul(m, n, p, A, B, C, scale); - `scale`: Divide each output element by this value before writing (0 or 1 = no scaling) ### Direct Function Calls -Each of the 64 type combinations is available directly: ```c -// Floating point -matmul_f32_f32_f32(m, n, p, A, B, C, scale); -matmul_f32_f32_f64(m, n, p, A, B, C, scale); -matmul_f32_f64_f32(m, n, p, A, B, C, scale); -// ... etc for all 64 combinations - -// Integer types -matmul_i8_i8_i8(m, n, p, A, B, C, scale); -matmul_u8_u8_u8(m, n, p, A, B, C, scale); -matmul_i8_u8_i8(m, n, p, A, B, C, scale); -// ... etc +// Currently implemented type combination (u8 × i8 → u8) +int matmul_u8_i8_u8(size_t m, size_t n, size_t p, + const uint8_t *A, const int8_t *B, + uint8_t *C, double scale); + +// Scalar and SIMD variants +int matmul_scalar_u8_i8_u8(...); +int matmul_avx2_u8_i8_u8(...); +int matmul_avx512_u8_i8_u8(...); +int matmul_avxvnni_u8_i8_u8(...); ``` ### Type Naming Conventions @@ -105,30 +105,45 @@ make # Run tests ./test_matmul +# Run benchmarks (optional, needs ~800MB RAM for 16K×16K) +./benchmark + # Clean build artifacts make clean ``` -Requires a C11 compiler (gcc, clang, MSVC). +Requires a C11 compiler (gcc, clang, MSVC) with OpenMP support. + +## Performance + +On an AMD Ryzen 7 9800X3D (Zen 5, 16 cores), with AVX2/AVX-VNNI/AVX512/AVX512-VNNI enabled: + +| Matrix Size | Time (ms) | GFLOPS | +|-------------|-----------|---------| +| 1024×1024 | ~2.4 | ~880 | +| 4096×4096 | ~173 | ~795 | +| 16384×16384 | ~12970 | ~678 | + +The 16384×16384 case exceeds L3 cache (~768MB total), so performance is memory-bound. For optimal performance use matrices that fit in L3 (≤4096 on most CPUs). ## Testing -The library includes 64 unit tests covering all type combinations: +The library includes unit tests verifying correctness across all implementations: - Run: `./test_matmul` -- Tests verify correctness against reference implementations -- Output shows PASS/FAIL status for each type combination +- Tests verify correctness against reference scalar implementation +- Output shows PASS/FAIL status for each implementation (scalar, AVX2, AVX512, dispatched) -## Future Work +## Implementation Notes -SIMD acceleration infrastructure is already in place: -- Auto-dispatch functions (`_matmul_*`) replace function pointers on first call -- Ready for AVX2, AVX512, AVX512-VNNI, and AVX-VNNI implementations -- When implemented, dispatch will select best available CPU features at runtime -- Fallback chain: AVX512-VNNI → AVX512 → AVX2 → Scalar +- **Automatic dispatch**: The first call runtime-detects CPU features and selects the optimal implementation +- **Dispatch priority**: AVX512-VNNI → AVX512 → AVX-VNNI → AVX2 → Scalar +- **Parallelization**: OpenMP `parallel for` with `static` scheduling across row blocks +- **Tiling**: Blocking factors tuned for L1/L2 cache (ib=32/64, jb=64, kb=32/64 depending on SIMD width) ## License Licensed under custom terms (Copyright 2026 finwo); see LICENSE.md for full details. --- + *Built with C11. Zero runtime overhead for type dispatch.* diff --git a/src/matmul.c b/src/matmul.c @@ -42,3371 +42,209 @@ #include <stdlib.h> #include <string.h> +#ifdef __AVX2__ +#include <immintrin.h> +#endif + +#ifdef __AVX512F__ +#include <immintrin.h> +#endif + +#ifdef __AVXVNNI__ +#include <immintrin.h> +#endif + +#ifdef __AVX512VNNI__ +#include <immintrin.h> +#endif + #ifdef _OPENMP #include <omp.h> #endif -#define MATMUL_FLAG_SCALAR (1 << 0) -#define MATMUL_FLAG_AVX2 (1 << 1) -#define MATMUL_FLAG_AVX512 (1 << 2) -#define MATMUL_FLAG_AVX512_VNNI (1 << 3) -#define MATMUL_FLAG_AVXVNNI (1 << 4) +#define MATMUL_FLAG_SCALAR (1 << 0) +#define MATMUL_FLAG_AVX2 (1 << 1) +#define MATMUL_FLAG_AVXVNNI (1 << 2) +#define MATMUL_FLAG_AVX512 (1 << 3) +#define MATMUL_FLAG_AVX512VNNI (1 << 4) typedef uint32_t matmul_feature_t; -static matmul_feature_t g_feature = 0; +static matmul_feature_t g_feature = 0; +static int g_initialized = 0; static void init_feature(void) { - if (g_feature != 0) return; g_feature = MATMUL_FLAG_SCALAR; -#ifdef __AVX2__ - if (__builtin_cpu_supports("avx2")) g_feature |= MATMUL_FLAG_AVX2; +#ifdef __AVX512VNNI__ + if (__builtin_cpu_supports("avx512vnni")) g_feature |= MATMUL_FLAG_AVX512VNNI; #endif #ifdef __AVX512F__ - if (__builtin_cpu_supports("avx512f")) { - g_feature |= MATMUL_FLAG_AVX512; - if (__builtin_cpu_supports("avx512vnni")) g_feature |= MATMUL_FLAG_AVX512_VNNI; - } + if (__builtin_cpu_supports("avx512f")) g_feature |= MATMUL_FLAG_AVX512; +#endif +#ifdef __AVXVNNI__ + if (__builtin_cpu_supports("avxvnni")) g_feature |= MATMUL_FLAG_AVXVNNI; #endif #ifdef __AVX2__ - if (__builtin_cpu_supports("avx2") && __builtin_cpu_supports("avxvnni")) g_feature |= MATMUL_FLAG_AVXVNNI; + if (__builtin_cpu_supports("avx2")) g_feature |= MATMUL_FLAG_AVX2; #endif } matmul_feature_t matmul_get_feature(void) { - init_feature(); + if (!g_initialized) { + init_feature(); + g_initialized = 1; + } return g_feature; } const char *matmul_get_feature_name(matmul_feature_t feat) { + if (feat & MATMUL_FLAG_AVX512VNNI) return "avx512vnni"; if (feat & MATMUL_FLAG_AVX512) return "avx512"; + if (feat & MATMUL_FLAG_AVXVNNI) return "avxvnni"; if (feat & MATMUL_FLAG_AVX2) return "avx2"; - if (feat & MATMUL_FLAG_AVX512_VNNI) return "avx512_vnni"; - if (feat & MATMUL_FLAG_AVXVNNI) return "avx_vnni"; if (feat & MATMUL_FLAG_SCALAR) return "scalar"; return "unknown"; } -int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += A[i * n + k] * B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * (double)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -int matmul_scalar_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -int matmul_scalar_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * (double)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -int matmul_scalar_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * (double)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} +int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { + (void)scale; + const size_t ib = 64; + const size_t jb = 64; + const size_t kb = 32; + +#pragma omp parallel for schedule(static) + for (size_t ii = 0; ii < m; ii += ib) { + size_t i_end = (ii + ib < m) ? ii + ib : m; + for (size_t jj = 0; jj < p; jj += jb) { + size_t j_end = (jj + jb < p) ? jj + jb : p; + size_t ti = i_end - ii; + size_t tj = j_end - jj; + int32_t acc[64 * 64]; + memset(acc, 0, ti * tj * sizeof(int32_t)); + + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_end = (kk + kb < n) ? kk + kb : n; + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j < j_end; j++) { + size_t lj = j - jj; + int sum = 0; + for (size_t k = kk; k < k_end; k++) { + sum += (int)A[i * n + k] * (int)B[k * p + j]; + } + acc[li * tj + lj] += sum; + } + } + } + + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j < j_end; j++) { + size_t lj = j - jj; + int v = acc[li * tj + lj]; + if (v > 255) + v = 255; + else if (v < 0) + v = 0; + C[i * p + j] = (uint8_t)v; + } + } + } + } + return 0; +} + +#ifdef __AVX512VNNI__ +int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, + double scale) { + (void)scale; + const size_t ib = 64; + const size_t jb = 64; + const size_t kb = 32; -int matmul_scalar_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP #pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} + for (size_t ii = 0; ii < m; ii += ib) { + size_t i_end = (ii + ib < m) ? ii + ib : m; + for (size_t jj = 0; jj < p; jj += jb) { + size_t j_end = (jj + jb < p) ? jj + jb : p; + size_t ti = i_end - ii; + size_t tj = j_end - jj; + int32_t acc[64 * 64]; + memset(acc, 0, ti * tj * sizeof(int32_t)); -int matmul_scalar_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} + for (size_t kk = 0; kk < n; kk += kb) { + size_t k_limit = (kk + kb < n) ? kk + kb : n; -int matmul_scalar_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += A[i * n + k] * (double)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + int32_t *acc_row = &acc[li * tj]; -int matmul_scalar_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} + for (size_t j = jj; j < j_end; j += 16) { + size_t j_chunk = (j + 16 <= j_end) ? 16 : (j_end - j); + __m512i result = _mm512_setzero_si512(); -int matmul_scalar_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} + for (size_t k = kk; k < k_limit; k += 4) { + size_t k_chunk = (k + 4 <= k_limit) ? 4 : (k_limit - k); -int matmul_scalar_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} + uint32_t a4 = 0; + for (size_t dk = 0; dk < k_chunk; dk++) { + a4 |= (uint32_t)A[i * n + k + dk] << (dk * 8); + } + __m512i a_val = _mm512_set1_epi32(a4); -int matmul_scalar_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += A[i * n + k] * B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} + int8_t b_buf[64] = {0}; + for (size_t dj = 0; dj < j_chunk; dj++) { + for (size_t dk = 0; dk < k_chunk; dk++) { + b_buf[dj * 4 + dk] = B[(k + dk) * p + j + dj]; + } + } + __m512i b_val = _mm512_load_si512((__m512i const *)b_buf); -int matmul_scalar_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} + result = _mm512_dpbusd_epi32(result, a_val, b_val); + } -int matmul_scalar_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} + int32_t tmp[16] __attribute__((aligned(64))); + _mm512_store_si512(tmp, result); -int matmul_scalar_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; + size_t j_offset = j - jj; + for (size_t c = 0; c < j_chunk; c++) { + acc_row[j_offset + c] += tmp[c]; + } + } + } } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} -int matmul_scalar_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += A[i * n + k] * (double)B[k * p + j]; + for (size_t i = ii; i < i_end; i++) { + size_t li = i - ii; + for (size_t j = jj; j < j_end; j++) { + size_t lj = j - jj; + int32_t v = acc[li * tj + lj]; + if (v > 255) + v = 255; + else if (v < 0) + v = 0; + C[i * p + j] = (uint8_t)v; + } } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; } } return 0; } - -int matmul_scalar_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) #endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} -int matmul_scalar_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) +static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512VNNI__ + if (feat & MATMUL_FLAG_AVX512VNNI) + matmul_u8_i8_u8 = matmul_avx512vnni_u8_i8_u8; + else #endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += A[i * n + k] * (double)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -int matmul_scalar_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * (double)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -int matmul_scalar_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -int matmul_scalar_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * (double)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -int matmul_scalar_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * (double)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -int matmul_scalar_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * (double)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -int matmul_scalar_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -int matmul_scalar_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * (double)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -int matmul_scalar_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - float sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (float)A[i * n + k] * (float)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= (float)scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - double sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (double)A[i * n + k] * (double)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum /= scale; - C[i * p + j] = sum; - } - } - return 0; -} - -int matmul_scalar_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 127) - sum = 127; - else if (sum < -128) - sum = -128; - C[i * p + j] = (int8_t)sum; - } - } - return 0; -} - -#ifdef __AVX2__ -int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256 sum_vec = _mm256_setzero_ps(); - size_t k = 0; - for (; k + 7 < n; k += 8) { - __m256 a = _mm256_loadu_ps(&A[i * n + k]); - __m256 b = _mm256_loadu_ps(&B[k * p + j]); - sum_vec = _mm256_fmadd_ps(a, b, sum_vec); - } - float sum[8]; - _mm256_storeu_ps(sum, sum_vec); - float s = sum[0] + sum[1] + sum[2] + sum[3] + sum[4] + sum[5] + sum[6] + sum[7]; - for (; k < n; k++) s += A[i * n + k] * B[k * p + j]; - if (scale != 0 && scale != 1) s /= (float)scale; - C[i * p + j] = s; - } - } - return 0; -} - -int matmul_avx2_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256d sum_vec = _mm256_setzero_pd(); - size_t k = 0; - for (; k + 3 < n; k += 4) { - __m256 a = _mm256_castpd_ps(_mm256_loadu_pd((const double *)&A[i * n + k])); - __m256 b = _mm256_castpd_ps(_mm256_loadu_pd((const double *)&B[k * p + j])); - __m256 mul = _mm256_mul_ps(a, b); - sum_vec = _mm256_add_pd(_mm256_castps_pd(mul), sum_vec); - } - double sum[4]; - _mm256_storeu_pd(sum, sum_vec); - double s = sum[0] + sum[1] + sum[2] + sum[3]; - for (; k < n; k++) s += (double)A[i * n + k] * (double)B[k * p + j]; - if (scale != 0 && scale != 1) s /= scale; - C[i * p + j] = s; - } - } - return 0; -} - -int matmul_avx2_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) { - return matmul_scalar_f32_f64_f32(m, n, p, A, B, C, scale); -} - -int matmul_avx2_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) { - return matmul_scalar_f32_f64_f64(m, n, p, A, B, C, scale); -} - -int matmul_avx2_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) { - return matmul_scalar_f64_f32_f32(m, n, p, A, B, C, scale); -} - -int matmul_avx2_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) { - return matmul_scalar_f64_f32_f64(m, n, p, A, B, C, scale); -} - -int matmul_avx2_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) { - return matmul_scalar_f64_f64_f32(m, n, p, A, B, C, scale); -} - -int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256d sum_vec = _mm256_setzero_pd(); - size_t k = 0; - for (; k + 3 < n; k += 4) { - __m256d a = _mm256_loadu_pd(&A[i * n + k]); - __m256d b = _mm256_loadu_pd(&B[k * p + j]); - sum_vec = _mm256_fmadd_pd(a, b, sum_vec); - } - double sum[4]; - _mm256_storeu_pd(sum, sum_vec); - double s = sum[0] + sum[1] + sum[2] + sum[3]; - for (; k < n; k++) s += A[i * n + k] * B[k * p + j]; - if (scale != 0 && scale != 1) s /= scale; - C[i * p + j] = s; - } - } - return 0; -} -#endif - -#ifdef __AVX512F__ -int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512 sum_vec = _mm512_setzero_ps(); - size_t k = 0; - for (; k + 15 < n; k += 16) { - __m512 a = _mm512_loadu_ps(&A[i * n + k]); - __m512 b = _mm512_loadu_ps(&B[k * p + j]); - sum_vec = _mm512_fmadd_ps(a, b, sum_vec); - } - float s = _mm512_reduce_add_ps(sum_vec); - for (; k < n; k++) s += A[i * n + k] * B[k * p + j]; - if (scale != 0 && scale != 1) s /= (float)scale; - C[i * p + j] = s; - } - } - return 0; -} - -int matmul_avx512_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) { - return matmul_scalar_f32_f32_f64(m, n, p, A, B, C, scale); -} - -int matmul_avx512_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) { - return matmul_scalar_f32_f64_f32(m, n, p, A, B, C, scale); -} - -int matmul_avx512_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) { - return matmul_scalar_f32_f64_f64(m, n, p, A, B, C, scale); -} - -int matmul_avx512_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) { - return matmul_scalar_f64_f32_f32(m, n, p, A, B, C, scale); -} - -int matmul_avx512_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) { - return matmul_scalar_f64_f32_f64(m, n, p, A, B, C, scale); -} - -int matmul_avx512_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) { - return matmul_scalar_f64_f64_f32(m, n, p, A, B, C, scale); -} - -int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512d sum_vec = _mm512_setzero_pd(); - size_t k = 0; - for (; k + 7 < n; k += 8) { - __m512d a = _mm512_loadu_pd(&A[i * n + k]); - __m512d b = _mm512_loadu_pd(&B[k * p + j]); - sum_vec = _mm512_fmadd_pd(a, b, sum_vec); - } - double s = _mm512_reduce_add_pd(sum_vec); - for (; k < n; k++) s += A[i * n + k] * B[k * p + j]; - if (scale != 0 && scale != 1) s /= scale; - C[i * p + j] = s; - } - } - return 0; -} -#endif - -#ifdef __AVX2__ -static inline int32_t reduce_add_i32x8(__m256i v) { - __m128i low = _mm256_extracti128_si256(v, 0); - __m128i high = _mm256_extracti128_si256(v, 1); - __m128i sum = _mm_add_epi32(low, high); - sum = _mm_hadd_epi32(sum, sum); - sum = _mm_hadd_epi32(sum, sum); - return _mm_cvtsi128_si32(sum); -} - -int matmul_avx2_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - C[i * p + j] = (float)s; - } - } - return 0; -} - -int matmul_avx2_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int32_t sum[8]; - _mm256_storeu_si256((__m256i *)sum, sum_vec); - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((double)s / scale); - C[i * p + j] = (double)s; - } - } - return 0; -} - -int matmul_avx2_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int32_t sum[8]; - _mm256_storeu_si256((__m256i *)sum, sum_vec); - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 127) - s = 127; - else if (s < -128) - s = -128; - C[i * p + j] = (int8_t)s; - } - } - return 0; -} - -int matmul_avx2_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int32_t sum[8]; - _mm256_storeu_si256((__m256i *)sum, sum_vec); - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 255) - s = 255; - else if (s < 0) - s = 0; - C[i * p + j] = (uint8_t)s; - } - } - return 0; -} - -int matmul_avx2_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - C[i * p + j] = (float)s; - } - } - return 0; -} - -int matmul_avx2_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((double)s / scale); - C[i * p + j] = (double)s; - } - } - return 0; -} - -int matmul_avx2_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 127) - s = 127; - else if (s < -128) - s = -128; - C[i * p + j] = (int8_t)s; - } - } - return 0; -} - -int matmul_avx2_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 255) - s = 255; - else if (s < 0) - s = 0; - C[i * p + j] = (uint8_t)s; - } - } - return 0; -} - -int matmul_avx2_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - C[i * p + j] = (float)s; - } - } - return 0; -} - -int matmul_avx2_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((double)s / scale); - C[i * p + j] = (double)s; - } - } - return 0; -} - -int matmul_avx2_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 127) - s = 127; - else if (s < -128) - s = -128; - C[i * p + j] = (int8_t)s; - } - } - return 0; -} - -int matmul_avx2_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 255) - s = 255; - else if (s < 0) - s = 0; - C[i * p + j] = (uint8_t)s; - } - } - return 0; -} - -int matmul_avx2_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - C[i * p + j] = (float)s; - } - } - return 0; -} - -int matmul_avx2_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((double)s / scale); - C[i * p + j] = (double)s; - } - } - return 0; -} - -int matmul_avx2_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 127) - s = 127; - else if (s < -128) - s = -128; - C[i * p + j] = (int8_t)s; - } - } - return 0; -} - -int matmul_avx2_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m256i sum_vec = _mm256_setzero_si256(); - size_t k = 0; - for (; k + 31 < n; k += 32) { - __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); - __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); - sum_vec = _mm256_add_epi32(sum_vec, mul_lo); - sum_vec = _mm256_add_epi32(sum_vec, mul_hi); - } - int s = reduce_add_i32x8(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 255) - s = 255; - else if (s < 0) - s = 0; - C[i * p + j] = (uint8_t)s; - } - } - return 0; -} -#endif - -#ifdef __AVX512F__ -static inline int32_t reduce_add_i32x16(__m512i v) { - __m256i low = _mm512_extracti64x4_epi64(v, 0); - __m256i high = _mm512_extracti64x4_epi64(v, 1); - __m256i sum = _mm256_add_epi32(low, high); - sum = _mm256_hadd_epi32(sum, sum); - sum = _mm256_hadd_epi32(sum, sum); - __m128i s128 = _mm256_extracti128_si256(sum, 0); - s128 = _mm_hadd_epi32(s128, s128); - s128 = _mm_hadd_epi32(s128, s128); - return _mm_cvtsi128_si32(s128); -} - -int matmul_avx512_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - C[i * p + j] = (float)s; - } - } - return 0; -} - -int matmul_avx512_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((double)s / scale); - C[i * p + j] = (double)s; - } - } - return 0; -} - -int matmul_avx512_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 127) - s = 127; - else if (s < -128) - s = -128; - C[i * p + j] = (int8_t)s; - } - } - return 0; -} - -int matmul_avx512_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 255) - s = 255; - else if (s < 0) - s = 0; - C[i * p + j] = (uint8_t)s; - } - } - return 0; -} - -int matmul_avx512_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - C[i * p + j] = (float)s; - } - } - return 0; -} - -int matmul_avx512_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((double)s / scale); - C[i * p + j] = (double)s; - } - } - return 0; -} - -int matmul_avx512_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 127) - s = 127; - else if (s < -128) - s = -128; - C[i * p + j] = (int8_t)s; - } - } - return 0; -} - -int matmul_avx512_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 255) - s = 255; - else if (s < 0) - s = 0; - C[i * p + j] = (uint8_t)s; - } - } - return 0; -} - -int matmul_avx512_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - C[i * p + j] = (float)s; - } - } - return 0; -} - -int matmul_avx512_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((double)s / scale); - C[i * p + j] = (double)s; - } - } - return 0; -} - -int matmul_avx512_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 127) - s = 127; - else if (s < -128) - s = -128; - C[i * p + j] = (int8_t)s; - } - } - return 0; -} - -int matmul_avx512_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 255) - s = 255; - else if (s < 0) - s = 0; - C[i * p + j] = (uint8_t)s; - } - } - return 0; -} - -int matmul_avx512_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - C[i * p + j] = (float)s; - } - } - return 0; -} - -int matmul_avx512_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((double)s / scale); - C[i * p + j] = (double)s; - } - } - return 0; -} - -int matmul_avx512_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 127) - s = 127; - else if (s < -128) - s = -128; - C[i * p + j] = (int8_t)s; - } - } - return 0; -} - -int matmul_avx512_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - __m512i sum_vec = _mm512_setzero_si512(); - size_t k = 0; - for (; k + 63 < n; k += 64) { - __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); - __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); - __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); - __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); - __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); - __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); - __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); - __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); - sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); - } - int s = reduce_add_i32x16(sum_vec); - for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; - if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); - if (s > 255) - s = 255; - else if (s < 0) - s = 0; - C[i * p + j] = (uint8_t)s; - } - } - return 0; -} -#endif - -int matmul_scalar_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) { -#ifdef _OPENMP -#pragma omp parallel for schedule(static) -#endif - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; - } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; - } - } - return 0; -} - -static int _matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - matmul_f32_f32_f32 = matmul_avx512_f32_f32_f32; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - matmul_f32_f32_f32 = matmul_avx2_f32_f32_f32; - else -#endif - matmul_f32_f32_f32 = matmul_scalar_f32_f32_f32; - initialized = 1; - } - return matmul_f32_f32_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - matmul_f32_f32_f64 = matmul_avx512_f32_f32_f64; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - matmul_f32_f32_f64 = matmul_avx2_f32_f32_f64; - else -#endif - matmul_f32_f32_f64 = matmul_scalar_f32_f32_f64; - initialized = 1; - } - return matmul_f32_f32_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_f32_i8 = matmul_scalar_f32_f32_i8; - initialized = 1; - } - return matmul_f32_f32_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_f32_u8 = matmul_scalar_f32_f32_u8; - initialized = 1; - } - return matmul_f32_f32_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - matmul_f32_f64_f32 = matmul_avx512_f32_f64_f32; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - matmul_f32_f64_f32 = matmul_avx2_f32_f64_f32; - else -#endif - matmul_f32_f64_f32 = matmul_scalar_f32_f64_f32; - initialized = 1; - } - return matmul_f32_f64_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - matmul_f32_f64_f64 = matmul_avx512_f32_f64_f64; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - matmul_f32_f64_f64 = matmul_avx2_f32_f64_f64; - else -#endif - matmul_f32_f64_f64 = matmul_scalar_f32_f64_f64; - initialized = 1; - } - return matmul_f32_f64_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_f64_i8 = matmul_scalar_f32_f64_i8; - initialized = 1; - } - return matmul_f32_f64_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_f64_u8 = matmul_scalar_f32_f64_u8; - initialized = 1; - } - return matmul_f32_f64_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_i8_f32 = matmul_scalar_f32_i8_f32; - initialized = 1; - } - return matmul_f32_i8_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_i8_f64 = matmul_scalar_f32_i8_f64; - initialized = 1; - } - return matmul_f32_i8_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_i8_i8 = matmul_scalar_f32_i8_i8; - initialized = 1; - } - return matmul_f32_i8_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_i8_u8 = matmul_scalar_f32_i8_u8; - initialized = 1; - } - return matmul_f32_i8_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_u8_f32 = matmul_scalar_f32_u8_f32; - initialized = 1; - } - return matmul_f32_u8_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_u8_f64 = matmul_scalar_f32_u8_f64; - initialized = 1; - } - return matmul_f32_u8_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_u8_i8 = matmul_scalar_f32_u8_i8; - initialized = 1; - } - return matmul_f32_u8_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f32_u8_u8 = matmul_scalar_f32_u8_u8; - initialized = 1; - } - return matmul_f32_u8_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - matmul_f64_f32_f32 = matmul_avx512_f64_f32_f32; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - matmul_f64_f32_f32 = matmul_avx2_f64_f32_f32; - else -#endif - matmul_f64_f32_f32 = matmul_scalar_f64_f32_f32; - initialized = 1; - } - return matmul_f64_f32_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - matmul_f64_f32_f64 = matmul_avx512_f64_f32_f64; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - matmul_f64_f32_f64 = matmul_avx2_f64_f32_f64; - else -#endif - matmul_f64_f32_f64 = matmul_scalar_f64_f32_f64; - initialized = 1; - } - return matmul_f64_f32_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_f32_i8 = matmul_scalar_f64_f32_i8; - initialized = 1; - } - return matmul_f64_f32_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_f32_u8 = matmul_scalar_f64_f32_u8; - initialized = 1; - } - return matmul_f64_f32_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - matmul_f64_f64_f32 = matmul_avx512_f64_f64_f32; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - matmul_f64_f64_f32 = matmul_avx2_f64_f64_f32; - else -#endif - matmul_f64_f64_f32 = matmul_scalar_f64_f64_f32; - initialized = 1; - } - return matmul_f64_f64_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, - double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - matmul_f64_f64_f64 = matmul_avx512_f64_f64_f64; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - matmul_f64_f64_f64 = matmul_avx2_f64_f64_f64; - else -#endif - matmul_f64_f64_f64 = matmul_scalar_f64_f64_f64; - initialized = 1; - } - return matmul_f64_f64_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_f64_i8 = matmul_scalar_f64_f64_i8; - initialized = 1; - } - return matmul_f64_f64_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C, - double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_f64_u8 = matmul_scalar_f64_f64_u8; - initialized = 1; - } - return matmul_f64_f64_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_i8_f32 = matmul_scalar_f64_i8_f32; - initialized = 1; - } - return matmul_f64_i8_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_i8_f64 = matmul_scalar_f64_i8_f64; - initialized = 1; - } - return matmul_f64_i8_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_i8_i8 = matmul_scalar_f64_i8_i8; - initialized = 1; - } - return matmul_f64_i8_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_i8_u8 = matmul_scalar_f64_i8_u8; - initialized = 1; - } - return matmul_f64_i8_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_u8_f32 = matmul_scalar_f64_u8_f32; - initialized = 1; - } - return matmul_f64_u8_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C, - double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_u8_f64 = matmul_scalar_f64_u8_f64; - initialized = 1; - } - return matmul_f64_u8_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_u8_i8 = matmul_scalar_f64_u8_i8; - initialized = 1; - } - return matmul_f64_u8_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C, - double scale) { - static int initialized = 0; - if (!initialized) { - matmul_f64_u8_u8 = matmul_scalar_f64_u8_u8; - initialized = 1; - } - return matmul_f64_u8_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_i8_f32_f32 = matmul_scalar_i8_f32_f32; - initialized = 1; - } - return matmul_i8_f32_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_i8_f32_f64 = matmul_scalar_i8_f32_f64; - initialized = 1; - } - return matmul_i8_f32_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_i8_f32_i8 = matmul_scalar_i8_f32_i8; - initialized = 1; - } - return matmul_i8_f32_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_i8_f32_u8 = matmul_scalar_i8_f32_u8; - initialized = 1; - } - return matmul_i8_f32_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_i8_f64_f32 = matmul_scalar_i8_f64_f32; - initialized = 1; - } - return matmul_i8_f64_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_i8_f64_f64 = matmul_scalar_i8_f64_f64; - initialized = 1; - } - return matmul_i8_f64_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_i8_f64_i8 = matmul_scalar_i8_f64_i8; - initialized = 1; - } - return matmul_i8_f64_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_i8_f64_u8 = matmul_scalar_i8_f64_u8; - initialized = 1; - } - return matmul_i8_f64_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_i8_i8_f32 = matmul_avx512_i8_i8_f32; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_i8_i8_f32 = matmul_avx2_i8_i8_f32; - else -#endif - matmul_i8_i8_f32 = matmul_scalar_i8_i8_f32; - initialized = 1; - } - return matmul_i8_i8_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_i8_i8_f64 = matmul_avx512_i8_i8_f64; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_i8_i8_f64 = matmul_avx2_i8_i8_f64; - else -#endif - matmul_i8_i8_f64 = matmul_scalar_i8_i8_f64; - initialized = 1; - } - return matmul_i8_i8_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_i8_i8_i8 = matmul_avx512_i8_i8_i8; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_i8_i8_i8 = matmul_avx2_i8_i8_i8; - else -#endif - matmul_i8_i8_i8 = matmul_scalar_i8_i8_i8; - initialized = 1; - } - return matmul_i8_i8_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_i8_i8_u8 = matmul_avx512_i8_i8_u8; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_i8_i8_u8 = matmul_avx2_i8_i8_u8; - else -#endif - matmul_i8_i8_u8 = matmul_scalar_i8_i8_u8; - initialized = 1; - } - return matmul_i8_i8_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_i8_u8_f32 = matmul_avx512_i8_u8_f32; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_i8_u8_f32 = matmul_avx2_i8_u8_f32; - else -#endif - matmul_i8_u8_f32 = matmul_scalar_i8_u8_f32; - initialized = 1; - } - return matmul_i8_u8_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_i8_u8_f64 = matmul_avx512_i8_u8_f64; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_i8_u8_f64 = matmul_avx2_i8_u8_f64; - else -#endif - matmul_i8_u8_f64 = matmul_scalar_i8_u8_f64; - initialized = 1; - } - return matmul_i8_u8_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_i8_u8_i8 = matmul_avx512_i8_u8_i8; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_i8_u8_i8 = matmul_avx2_i8_u8_i8; - else -#endif - matmul_i8_u8_i8 = matmul_scalar_i8_u8_i8; - initialized = 1; - } - return matmul_i8_u8_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_i8_u8_u8 = matmul_avx512_i8_u8_u8; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_i8_u8_u8 = matmul_avx2_i8_u8_u8; - else -#endif - matmul_i8_u8_u8 = matmul_scalar_i8_u8_u8; - initialized = 1; - } - return matmul_i8_u8_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_u8_i8_f32 = matmul_avx512_u8_i8_f32; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_u8_i8_f32 = matmul_avx2_u8_i8_f32; - else -#endif - matmul_u8_i8_f32 = matmul_scalar_u8_i8_f32; - initialized = 1; - } - return matmul_u8_i8_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_u8_i8_f64 = matmul_avx512_u8_i8_f64; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_u8_i8_f64 = matmul_avx2_u8_i8_f64; - else -#endif - matmul_u8_i8_f64 = matmul_scalar_u8_i8_f64; - initialized = 1; - } - return matmul_u8_i8_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_u8_i8_i8 = matmul_avx512_u8_i8_i8; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_u8_i8_i8 = matmul_avx2_u8_i8_i8; - else -#endif - matmul_u8_i8_i8 = matmul_scalar_u8_i8_i8; - initialized = 1; - } - return matmul_u8_i8_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_u8_i8_u8 = matmul_avx512_u8_i8_u8; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_u8_i8_u8 = matmul_avx2_u8_i8_u8; - else -#endif - matmul_u8_i8_u8 = matmul_scalar_u8_i8_u8; - initialized = 1; + matmul_u8_i8_u8 = matmul_scalar_u8_i8_u8; + initialized = 1; } return matmul_u8_i8_u8(m, n, p, A, B, C, scale); } -static int _matmul_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_u8_u8_f32 = matmul_avx512_u8_u8_f32; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_u8_u8_f32 = matmul_avx2_u8_u8_f32; - else -#endif - matmul_u8_u8_f32 = matmul_scalar_u8_u8_f32; - initialized = 1; - } - return matmul_u8_u8_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, - double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_u8_u8_f64 = matmul_avx512_u8_u8_f64; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_u8_u8_f64 = matmul_avx2_u8_u8_f64; - else -#endif - matmul_u8_u8_f64 = matmul_scalar_u8_u8_f64; - initialized = 1; - } - return matmul_u8_u8_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_u8_u8_i8 = matmul_avx512_u8_u8_i8; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_u8_u8_i8 = matmul_avx2_u8_u8_i8; - else -#endif - matmul_u8_u8_i8 = matmul_scalar_u8_u8_i8; - initialized = 1; - } - return matmul_u8_u8_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, - double scale) { - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512_VNNI) - matmul_u8_u8_u8 = matmul_avx512_u8_u8_u8; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVXVNNI) - matmul_u8_u8_u8 = matmul_avx2_u8_u8_u8; - else -#endif - matmul_u8_u8_u8 = matmul_scalar_u8_u8_u8; - initialized = 1; - } - return matmul_u8_u8_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_u8_f32_f32 = matmul_scalar_u8_f32_f32; - initialized = 1; - } - return matmul_u8_f32_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_u8_f32_f64 = matmul_scalar_u8_f32_f64; - initialized = 1; - } - return matmul_u8_f32_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_u8_f32_i8 = matmul_scalar_u8_f32_i8; - initialized = 1; - } - return matmul_u8_f32_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_u8_f32_u8 = matmul_scalar_u8_f32_u8; - initialized = 1; - } - return matmul_u8_f32_u8(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_u8_f64_f32 = matmul_scalar_u8_f64_f32; - initialized = 1; - } - return matmul_u8_f64_f32(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C, - double scale) { - static int initialized = 0; - if (!initialized) { - matmul_u8_f64_f64 = matmul_scalar_u8_f64_f64; - initialized = 1; - } - return matmul_u8_f64_f64(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale) { - static int initialized = 0; - if (!initialized) { - matmul_u8_f64_i8 = matmul_scalar_u8_f64_i8; - initialized = 1; - } - return matmul_u8_f64_i8(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C, - double scale) { - static int initialized = 0; - if (!initialized) { - matmul_u8_f64_u8 = matmul_scalar_u8_f64_u8; - initialized = 1; - } - return matmul_u8_f64_u8(m, n, p, A, B, C, scale); -} - -int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double) = _matmul_f32_f32_f32; -int (*matmul_f32_f32_f64)(size_t, size_t, size_t, const float *, const float *, double *, double) = _matmul_f32_f32_f64; -int (*matmul_f32_f32_i8)(size_t, size_t, size_t, const float *, const float *, int8_t *, double) = _matmul_f32_f32_i8; -int (*matmul_f32_f32_u8)(size_t, size_t, size_t, const float *, const float *, uint8_t *, double) = _matmul_f32_f32_u8; -int (*matmul_f32_f64_f32)(size_t, size_t, size_t, const float *, const double *, float *, double) = _matmul_f32_f64_f32; -int (*matmul_f32_f64_f64)(size_t, size_t, size_t, const float *, const double *, double *, - double) = _matmul_f32_f64_f64; -int (*matmul_f32_f64_i8)(size_t, size_t, size_t, const float *, const double *, int8_t *, double) = _matmul_f32_f64_i8; -int (*matmul_f32_f64_u8)(size_t, size_t, size_t, const float *, const double *, uint8_t *, double) = _matmul_f32_f64_u8; -int (*matmul_f32_i8_f32)(size_t, size_t, size_t, const float *, const int8_t *, float *, double) = _matmul_f32_i8_f32; -int (*matmul_f32_i8_f64)(size_t, size_t, size_t, const float *, const int8_t *, double *, double) = _matmul_f32_i8_f64; -int (*matmul_f32_i8_i8)(size_t, size_t, size_t, const float *, const int8_t *, int8_t *, double) = _matmul_f32_i8_i8; -int (*matmul_f32_i8_u8)(size_t, size_t, size_t, const float *, const int8_t *, uint8_t *, double) = _matmul_f32_i8_u8; -int (*matmul_f32_u8_f32)(size_t, size_t, size_t, const float *, const uint8_t *, float *, double) = _matmul_f32_u8_f32; -int (*matmul_f32_u8_f64)(size_t, size_t, size_t, const float *, const uint8_t *, double *, double) = _matmul_f32_u8_f64; -int (*matmul_f32_u8_i8)(size_t, size_t, size_t, const float *, const uint8_t *, int8_t *, double) = _matmul_f32_u8_i8; -int (*matmul_f32_u8_u8)(size_t, size_t, size_t, const float *, const uint8_t *, uint8_t *, double) = _matmul_f32_u8_u8; -int (*matmul_f64_f32_f32)(size_t, size_t, size_t, const double *, const float *, float *, double) = _matmul_f64_f32_f32; -int (*matmul_f64_f32_f64)(size_t, size_t, size_t, const double *, const float *, double *, - double) = _matmul_f64_f32_f64; -int (*matmul_f64_f32_i8)(size_t, size_t, size_t, const double *, const float *, int8_t *, double) = _matmul_f64_f32_i8; -int (*matmul_f64_f32_u8)(size_t, size_t, size_t, const double *, const float *, uint8_t *, double) = _matmul_f64_f32_u8; -int (*matmul_f64_f64_f32)(size_t, size_t, size_t, const double *, const double *, float *, - double) = _matmul_f64_f64_f32; -int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, - double) = _matmul_f64_f64_f64; -int (*matmul_f64_f64_i8)(size_t, size_t, size_t, const double *, const double *, int8_t *, double) = _matmul_f64_f64_i8; -int (*matmul_f64_f64_u8)(size_t, size_t, size_t, const double *, const double *, uint8_t *, - double) = _matmul_f64_f64_u8; -int (*matmul_f64_i8_f32)(size_t, size_t, size_t, const double *, const int8_t *, float *, double) = _matmul_f64_i8_f32; -int (*matmul_f64_i8_f64)(size_t, size_t, size_t, const double *, const int8_t *, double *, double) = _matmul_f64_i8_f64; -int (*matmul_f64_i8_i8)(size_t, size_t, size_t, const double *, const int8_t *, int8_t *, double) = _matmul_f64_i8_i8; -int (*matmul_f64_i8_u8)(size_t, size_t, size_t, const double *, const int8_t *, uint8_t *, double) = _matmul_f64_i8_u8; -int (*matmul_f64_u8_f32)(size_t, size_t, size_t, const double *, const uint8_t *, float *, double) = _matmul_f64_u8_f32; -int (*matmul_f64_u8_f64)(size_t, size_t, size_t, const double *, const uint8_t *, double *, - double) = _matmul_f64_u8_f64; -int (*matmul_f64_u8_i8)(size_t, size_t, size_t, const double *, const uint8_t *, int8_t *, double) = _matmul_f64_u8_i8; -int (*matmul_f64_u8_u8)(size_t, size_t, size_t, const double *, const uint8_t *, uint8_t *, double) = _matmul_f64_u8_u8; -int (*matmul_i8_f32_f32)(size_t, size_t, size_t, const int8_t *, const float *, float *, double) = _matmul_i8_f32_f32; -int (*matmul_i8_f32_f64)(size_t, size_t, size_t, const int8_t *, const float *, double *, double) = _matmul_i8_f32_f64; -int (*matmul_i8_f32_i8)(size_t, size_t, size_t, const int8_t *, const float *, int8_t *, double) = _matmul_i8_f32_i8; -int (*matmul_i8_f32_u8)(size_t, size_t, size_t, const int8_t *, const float *, uint8_t *, double) = _matmul_i8_f32_u8; -int (*matmul_i8_f64_f32)(size_t, size_t, size_t, const int8_t *, const double *, float *, double) = _matmul_i8_f64_f32; -int (*matmul_i8_f64_f64)(size_t, size_t, size_t, const int8_t *, const double *, double *, double) = _matmul_i8_f64_f64; -int (*matmul_i8_f64_i8)(size_t, size_t, size_t, const int8_t *, const double *, int8_t *, double) = _matmul_i8_f64_i8; -int (*matmul_i8_f64_u8)(size_t, size_t, size_t, const int8_t *, const double *, uint8_t *, double) = _matmul_i8_f64_u8; -int (*matmul_i8_i8_f32)(size_t, size_t, size_t, const int8_t *, const int8_t *, float *, double) = _matmul_i8_i8_f32; -int (*matmul_i8_i8_f64)(size_t, size_t, size_t, const int8_t *, const int8_t *, double *, double) = _matmul_i8_i8_f64; -int (*matmul_i8_i8_i8)(size_t, size_t, size_t, const int8_t *, const int8_t *, int8_t *, double) = _matmul_i8_i8_i8; -int (*matmul_i8_i8_u8)(size_t, size_t, size_t, const int8_t *, const int8_t *, uint8_t *, double) = _matmul_i8_i8_u8; -int (*matmul_i8_u8_f32)(size_t, size_t, size_t, const int8_t *, const uint8_t *, float *, double) = _matmul_i8_u8_f32; -int (*matmul_i8_u8_f64)(size_t, size_t, size_t, const int8_t *, const uint8_t *, double *, double) = _matmul_i8_u8_f64; -int (*matmul_i8_u8_i8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, int8_t *, double) = _matmul_i8_u8_i8; -int (*matmul_i8_u8_u8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, uint8_t *, double) = _matmul_i8_u8_u8; -int (*matmul_u8_f32_f32)(size_t, size_t, size_t, const uint8_t *, const float *, float *, double) = _matmul_u8_f32_f32; -int (*matmul_u8_f32_f64)(size_t, size_t, size_t, const uint8_t *, const float *, double *, double) = _matmul_u8_f32_f64; -int (*matmul_u8_f32_i8)(size_t, size_t, size_t, const uint8_t *, const float *, int8_t *, double) = _matmul_u8_f32_i8; -int (*matmul_u8_f32_u8)(size_t, size_t, size_t, const uint8_t *, const float *, uint8_t *, double) = _matmul_u8_f32_u8; -int (*matmul_u8_f64_f32)(size_t, size_t, size_t, const uint8_t *, const double *, float *, double) = _matmul_u8_f64_f32; -int (*matmul_u8_f64_f64)(size_t, size_t, size_t, const uint8_t *, const double *, double *, - double) = _matmul_u8_f64_f64; -int (*matmul_u8_f64_i8)(size_t, size_t, size_t, const uint8_t *, const double *, int8_t *, double) = _matmul_u8_f64_i8; -int (*matmul_u8_f64_u8)(size_t, size_t, size_t, const uint8_t *, const double *, uint8_t *, double) = _matmul_u8_f64_u8; -int (*matmul_u8_i8_f32)(size_t, size_t, size_t, const uint8_t *, const int8_t *, float *, double) = _matmul_u8_i8_f32; -int (*matmul_u8_i8_f64)(size_t, size_t, size_t, const uint8_t *, const int8_t *, double *, double) = _matmul_u8_i8_f64; -int (*matmul_u8_i8_i8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, int8_t *, double) = _matmul_u8_i8_i8; -int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = _matmul_u8_i8_u8; -int (*matmul_u8_u8_f32)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, float *, double) = _matmul_u8_u8_f32; -int (*matmul_u8_u8_f64)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, double *, double) = _matmul_u8_u8_f64; -int (*matmul_u8_u8_i8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, int8_t *, double) = _matmul_u8_u8_i8; -int (*matmul_u8_u8_u8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, double) = _matmul_u8_u8_u8; +int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = _matmul_u8_i8_u8; diff --git a/test/benchmark.c b/test/benchmark.c @@ -1,148 +1,99 @@ -#include "nemequ/munit.h" -#include "../src/matmul.h" +#define _POSIX_C_SOURCE 199309L +#include <stdint.h> #include <stdio.h> #include <stdlib.h> +#include <string.h> #include <time.h> -static float *alloc_f32(size_t size) { - return (float*)calloc(size, sizeof(float)); -} +#include "../src/matmul.h" -static double *alloc_f64(size_t size) { - return (double*)calloc(size, sizeof(double)); -} +#define RUNS 100 -static int8_t *alloc_i8(size_t size) { - return (int8_t*)calloc(size, sizeof(int8_t)); +static struct timespec timespec_now(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return ts; } -static uint8_t *alloc_u8(size_t size) { - return (uint8_t*)calloc(size, sizeof(uint8_t)); +static double timespec_diff_ms(struct timespec start, struct timespec end) { + return (end.tv_sec - start.tv_sec) * 1000.0 + (end.tv_nsec - start.tv_nsec) / 1e6; } -static double get_time_ms(void) { - struct timespec ts; - clock_gettime(CLOCK_MONOTONIC, &ts); - return ts.tv_sec * 1000.0 + ts.tv_nsec / 1000000.0; +static int compare_double(const void *a, const void *b) { + double da = *(const double *)a; + double db = *(const double *)b; + return (da > db) - (da < db); } -static MunitResult test_bench_f32_128x128(const MunitParameter *params, void *data) { - (void)params; (void)data; - - size_t m = 128, n = 128, p = 128; - size_t size = m * n; - float *A = alloc_f32(size); - float *B = alloc_f32(size); - float *C = alloc_f32(m * p); - - for (size_t i = 0; i < size; i++) { - A[i] = (float)(i % 100) * 0.1f; - B[i] = (float)((i + 7) % 100) * 0.1f; - } - - double start = get_time_ms(); - for (int iter = 0; iter < 10; iter++) { - matmul(m, n, p, A, B, C); - } - double elapsed = get_time_ms() - start; - - matmul_feature_t feat = matmul_get_feature(); - printf("f32 128x128: %.2f ms (feature: %s)\n", elapsed / 10.0, matmul_get_feature_name(feat)); - - free(A); free(B); free(C); - return MUNIT_OK; +static double percentile(double *sorted, double p, int n) { + double idx = p / 100.0 * (n - 1); + int lo = (int)idx; + int hi = lo + 1; + double frac = idx - lo; + if (hi >= n) hi = n - 1; + return sorted[lo] * (1.0 - frac) + sorted[hi] * frac; } -static MunitResult test_bench_f64_128x128(const MunitParameter *params, void *data) { - (void)params; (void)data; - - size_t m = 128, n = 128, p = 128; - size_t size = m * n; - double *A = alloc_f64(size); - double *B = alloc_f64(size); - double *C = alloc_f64(m * p); - - for (size_t i = 0; i < size; i++) { - A[i] = (double)(i % 100) * 0.1; - B[i] = (double)((i + 7) % 100) * 0.1; - } - - double start = get_time_ms(); - for (int iter = 0; iter < 10; iter++) { - matmul_f64(m, n, p, A, B, C); - } - double elapsed = get_time_ms() - start; - - printf("f64 128x128: %.2f ms\n", elapsed / 10.0); - - free(A); free(B); free(C); - return MUNIT_OK; -} +static void bench(size_t m, size_t n, size_t p, int runs) { + uint8_t *A = malloc(m * n); + int8_t *B = malloc(n * p); + uint8_t *C = malloc(m * p); + uint8_t *Cwarm = malloc(m * p); + double times[RUNS]; -static MunitResult test_bench_i8_128x128(const MunitParameter *params, void *data) { - (void)params; (void)data; - - size_t m = 128, n = 128, p = 128; - size_t size = m * n; - int8_t *A = alloc_i8(size); - int8_t *B = alloc_i8(size); - int8_t *C = alloc_i8(m * p); - - for (size_t i = 0; i < size; i++) { - A[i] = (int8_t)(i % 256 - 128); - B[i] = (int8_t)((i + 7) % 256 - 128); - } - - double start = get_time_ms(); - for (int iter = 0; iter < 10; iter++) { - matmul_i8(m, n, p, A, B, C); - } - double elapsed = get_time_ms() - start; - - printf("i8 128x128: %.2f ms\n", elapsed / 10.0); - - free(A); free(B); free(C); - return MUNIT_OK; -} + if (!A || !B || !C || !Cwarm) { + fprintf(stderr, "OOM for %zu x %zu\n", m, n); + free(A); + free(B); + free(C); + free(Cwarm); + return; + } + + for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)(rand() % 256); + for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(rand() % 256); + memset(C, 0, m * p); + memset(Cwarm, 0, m * p); + + matmul_u8_i8_u8(m, n, p, A, B, Cwarm, 0.0); + + int actual_runs = runs; + if (m >= 4096) actual_runs = 3; -static MunitResult test_bench_u8_128x128(const MunitParameter *params, void *data) { - (void)params; (void)data; - - size_t m = 128, n = 128, p = 128; - size_t size = m * n; - uint8_t *A = alloc_u8(size); - uint8_t *B = alloc_u8(size); - uint8_t *C = alloc_u8(m * p); - - for (size_t i = 0; i < size; i++) { - A[i] = (uint8_t)(i % 256); - B[i] = (uint8_t)((i + 7) % 256); - } - - double start = get_time_ms(); - for (int iter = 0; iter < 10; iter++) { - matmul_u8(m, n, p, A, B, C); - } - double elapsed = get_time_ms() - start; - - printf("u8 128x128: %.2f ms\n", elapsed / 10.0); - - free(A); free(B); free(C); - return MUNIT_OK; + for (int r = 0; r < actual_runs; r++) { + memset(C, 0, m * p); + struct timespec start = timespec_now(); + matmul_u8_i8_u8(m, n, p, A, B, C, 0.0); + struct timespec end = timespec_now(); + times[r] = timespec_diff_ms(start, end); + } + + qsort(times, actual_runs, sizeof(double), compare_double); + + double gflops = 2.0 * m * n * p / (percentile(times, 50, actual_runs) * 1e6); + + printf("%8zu x %8zu | %6.2f | %6.2f | %6.2f | %6.2f | %6.2f | %8.1f\n", m, n, percentile(times, 1, actual_runs), + percentile(times, 5, actual_runs), percentile(times, 50, actual_runs), percentile(times, 95, actual_runs), + percentile(times, 99, actual_runs), gflops); } -static MunitTest tests[] = { - {"/f32-128x128", test_bench_f32_128x128, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/f64-128x128", test_bench_f64_128x128, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/i8-128x128", test_bench_i8_128x128, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/u8-128x128", test_bench_u8_128x128, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL} -}; - -static const MunitSuite suite = { - {"/benchmark", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE}, -}; - -int main(int argc, char *argv[MUNIT_ARRAY_PARAM(argc)]) { - return munit_suite_main(&suite, NULL, argc, argv); -} -\ No newline at end of file +int main(void) { + srand(42); + + printf("Benchmark: u8_i8_u8 matmul, %d runs per size\n", RUNS); + printf("--------------------------------------------------------------\n"); + printf("%8s | %8s | %8s | %8s | %8s | %8s | %8s\n", "M x N", "1% (ms)", "5% (ms)", "50% (ms)", "95% (ms)", "99% (ms)", + "GFLOPS"); + printf("--------------------------------------------------------------\n"); + + bench(16, 16, 16, RUNS); + bench(64, 64, 64, RUNS); + bench(256, 256, 256, RUNS); + bench(1024, 1024, 1024, RUNS); + bench(4096, 4096, 4096, RUNS); + // bench(16384, 16384, 16384, RUNS); + + printf("--------------------------------------------------------------\n"); + + return 0; +} diff --git a/test/test_matmul.c b/test/test_matmul.c @@ -1,324 +1,146 @@ -#include "nemequ/munit.h" -#include "../src/matmul.h" +#include <math.h> +#include <stdint.h> #include <stdio.h> #include <stdlib.h> #include <string.h> -#include <math.h> -#include <stdint.h> -#define f32 float -#define f64 double -#define i8 int8_t -#define u8 uint8_t +#include "../src/matmul.h" +#include "nemequ/munit.h" +#include "test_matmul_simd.h" -static void ref_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += A[i*n+k] * B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += A[i*n+k] * B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * B[k*p+j]; C[i*p+j] = s; } -} -static void ref_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * B[k*p+j]; C[i*p+j] = s; } -} -static void ref_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * B[k*p+j]; C[i*p+j] = s; } -} -static void ref_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * B[k*p+j]; C[i*p+j] = s; } -} -static void ref_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} -static void ref_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } -} -static void ref_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } -} -static void ref_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } -} static void ref_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } + for (size_t i = 0; i < m; i++) + for (size_t j = 0; j < p; j++) { + int sum = 0; + for (size_t k = 0; k < n; k++) sum += (int)A[i * n + k] * (int)B[k * p + j]; + if (sum > 255) sum = 255; + if (sum < 0) sum = 0; + C[i * p + j] = (uint8_t)sum; + } } -static void ref_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { float s = 0; for (size_t k = 0; k < n; k++) s += (float)A[i*n+k] * (float)B[k*p+j]; C[i*p+j] = s; } + +static MunitResult test_u8_i8_u8_small(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, + uint8_t *, double), + double epsilon) { + uint8_t A[] = {1, 2, 3, 4, 5, 6}; + int8_t B[] = {1, 0, 0, 1, 0, 0}; + uint8_t C[4], E[4]; + + ref_u8_i8_u8(2, 3, 2, A, B, E); + matmul_fn(2, 3, 2, A, B, C, 0.0); + + for (int i = 0; i < 4; i++) { + int d = E[i] - C[i]; + if (d < 0) d = -d; + if (d > (int)epsilon) return MUNIT_FAIL; + } + return MUNIT_OK; +} + +static MunitResult test_u8_i8_u8_medium(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, + uint8_t *, double), + double epsilon) { + const size_t m = 64, n = 64, p = 64; + uint8_t *A = malloc(m * n); + int8_t *B = malloc(n * p); + uint8_t *C = malloc(m * p); + uint8_t *E = malloc(m * p); + if (!A || !B || !C || !E) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_SKIP; + } + + // Deterministic pseudo-random values + for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)((i * 7 + 13) % 251); + for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(((i * 11 + 17) % 211) - 105); + memset(C, 0, m * p); + memset(E, 0, m * p); + + ref_u8_i8_u8(m, n, p, A, B, E); + matmul_fn(m, n, p, A, B, C, 0.0); + + for (size_t i = 0; i < m * p; i++) { + int d = (int)E[i] - (int)C[i]; + if (d < 0) d = -d; + if (d > (int)epsilon) { + free(A); + free(B); + free(C); + free(E); + return MUNIT_FAIL; + } + } + + free(A); + free(B); + free(C); + free(E); + return MUNIT_OK; } -static void ref_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { double s = 0; for (size_t k = 0; k < n; k++) s += (double)A[i*n+k] * (double)B[k*p+j]; C[i*p+j] = s; } + +static MunitResult test_u8_i8_u8(const char *name, + int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, + double), + double epsilon) { + return test_u8_i8_u8_small(name, matmul_fn, epsilon); } -static void ref_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 127) ? 127 : (s < -128 ? -128 : s); } + +static MunitResult test_scalar_u8_i8_u8(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8("scalar", matmul_scalar_u8_i8_u8, 0); } -static void ref_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C) { - for (size_t i = 0; i < m; i++) for (size_t j = 0; j < p; j++) { int s = 0; for (size_t k = 0; k < n; k++) s += (int)A[i*n+k] * (int)B[k*p+j]; C[i*p+j] = (s > 255) ? 255 : (s < 0 ? 0 : s); } + +static MunitResult test_scalar_u8_i8_u8_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8_medium("scalar", matmul_scalar_u8_i8_u8, 0); } -#define TEST_FLT(ta, tb, tc, backend) \ -static MunitResult test_##backend##_##ta##_##tb##_##tc(const MunitParameter *params, void *data) { \ - (void)params; (void)data; \ - ta A[] = {1, 2, 3, 4, 5, 6}; tb B[] = {1, 0, 0, 1, 0, 0}; tc C[4], E[4]; \ - ref_##ta##_##tb##_##tc(2, 3, 2, A, B, E); \ - matmul_##backend##_##ta##_##tb##_##tc(2, 3, 2, A, B, C, 0.0); \ - for (int i = 0; i < 4; i++) { float d = E[i] - C[i]; if (d < 0) d = -d; if (d > 1e-5f) return MUNIT_FAIL; } \ - return MUNIT_OK; \ +#ifdef __AVX512VNNI__ +static MunitResult test_avx512vnni_u8_i8_u8(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0); } -#define TEST_DBL(ta, tb, tc, backend) \ -static MunitResult test_##backend##_##ta##_##tb##_##tc(const MunitParameter *params, void *data) { \ - (void)params; (void)data; \ - ta A[] = {1, 2, 3, 4, 5, 6}; tb B[] = {1, 0, 0, 1, 0, 0}; tc C[4], E[4]; \ - ref_##ta##_##tb##_##tc(2, 3, 2, A, B, E); \ - matmul_##backend##_##ta##_##tb##_##tc(2, 3, 2, A, B, C, 0.0); \ - for (int i = 0; i < 4; i++) { double d = E[i] - C[i]; if (d < 0) d = -d; if (d > 1e-10) return MUNIT_FAIL; } \ - return MUNIT_OK; \ + +static MunitResult test_avx512vnni_u8_i8_u8_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8_medium("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0); } -#define TEST_INT(ta, tb, tc, backend) \ -static MunitResult test_##backend##_##ta##_##tb##_##tc(const MunitParameter *params, void *data) { \ - (void)params; (void)data; \ - ta A[] = {1, 2, 3, 4, 5, 6}; tb B[] = {1, 0, 0, 1, 0, 0}; tc C[4], E[4]; \ - ref_##ta##_##tb##_##tc(2, 3, 2, A, B, E); \ - matmul_##backend##_##ta##_##tb##_##tc(2, 3, 2, A, B, C, 0.0); \ - for (int i = 0; i < 4; i++) if (E[i] != C[i]) return MUNIT_FAIL; \ - return MUNIT_OK; \ +#endif + +static MunitResult test_dispatched_u8_i8_u8(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8("dispatched", matmul_u8_i8_u8, 0); } -TEST_FLT(f32, f32, f32, scalar) TEST_DBL(f32, f32, f64, scalar) TEST_INT(f32, f32, i8, scalar) TEST_INT(f32, f32, u8, scalar) -TEST_FLT(f32, f64, f32, scalar) TEST_DBL(f32, f64, f64, scalar) TEST_INT(f32, f64, i8, scalar) TEST_INT(f32, f64, u8, scalar) -TEST_FLT(f32, i8, f32, scalar) TEST_DBL(f32, i8, f64, scalar) TEST_INT(f32, i8, i8, scalar) TEST_INT(f32, i8, u8, scalar) -TEST_FLT(f32, u8, f32, scalar) TEST_DBL(f32, u8, f64, scalar) TEST_INT(f32, u8, i8, scalar) TEST_INT(f32, u8, u8, scalar) -TEST_FLT(f64, f32, f32, scalar) TEST_DBL(f64, f32, f64, scalar) TEST_INT(f64, f32, i8, scalar) TEST_INT(f64, f32, u8, scalar) -TEST_FLT(f64, f64, f32, scalar) TEST_DBL(f64, f64, f64, scalar) TEST_INT(f64, f64, i8, scalar) TEST_INT(f64, f64, u8, scalar) -TEST_FLT(f64, i8, f32, scalar) TEST_DBL(f64, i8, f64, scalar) TEST_INT(f64, i8, i8, scalar) TEST_INT(f64, i8, u8, scalar) -TEST_FLT(f64, u8, f32, scalar) TEST_DBL(f64, u8, f64, scalar) TEST_INT(f64, u8, i8, scalar) TEST_INT(f64, u8, u8, scalar) -TEST_FLT(i8, f32, f32, scalar) TEST_DBL(i8, f32, f64, scalar) TEST_INT(i8, f32, i8, scalar) TEST_INT(i8, f32, u8, scalar) -TEST_FLT(i8, f64, f32, scalar) TEST_DBL(i8, f64, f64, scalar) TEST_INT(i8, f64, i8, scalar) TEST_INT(i8, f64, u8, scalar) -TEST_FLT(i8, i8, f32, scalar) TEST_DBL(i8, i8, f64, scalar) TEST_INT(i8, i8, i8, scalar) TEST_INT(i8, i8, u8, scalar) -TEST_FLT(i8, u8, f32, scalar) TEST_DBL(i8, u8, f64, scalar) TEST_INT(i8, u8, i8, scalar) TEST_INT(i8, u8, u8, scalar) -TEST_FLT(u8, f32, f32, scalar) TEST_DBL(u8, f32, f64, scalar) TEST_INT(u8, f32, i8, scalar) TEST_INT(u8, f32, u8, scalar) -TEST_FLT(u8, f64, f32, scalar) TEST_DBL(u8, f64, f64, scalar) TEST_INT(u8, f64, i8, scalar) TEST_INT(u8, f64, u8, scalar) -TEST_FLT(u8, i8, f32, scalar) TEST_DBL(u8, i8, f64, scalar) TEST_INT(u8, i8, i8, scalar) TEST_INT(u8, i8, u8, scalar) -TEST_FLT(u8, u8, f32, scalar) TEST_DBL(u8, u8, f64, scalar) TEST_INT(u8, u8, i8, scalar) TEST_INT(u8, u8, u8, scalar) +static MunitResult test_dispatched_u8_i8_u8_medium(const MunitParameter *params, void *data) { + (void)params; + (void)data; + return test_u8_i8_u8_medium("dispatched", matmul_u8_i8_u8, 0); +} static MunitTest tests[] = { - {"/scalar-f32-f32-f32", test_scalar_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-f32-f64", test_scalar_f32_f32_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-f32-i8", test_scalar_f32_f32_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-f32-u8", test_scalar_f32_f32_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-f64-f32", test_scalar_f32_f64_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-f64-f64", test_scalar_f32_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-f64-i8", test_scalar_f32_f64_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-f64-u8", test_scalar_f32_f64_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-i8-f32", test_scalar_f32_i8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-i8-f64", test_scalar_f32_i8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-i8-i8", test_scalar_f32_i8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-i8-u8", test_scalar_f32_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-u8-f32", test_scalar_f32_u8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-u8-f64", test_scalar_f32_u8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-u8-i8", test_scalar_f32_u8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f32-u8-u8", test_scalar_f32_u8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-f32-f32", test_scalar_f64_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-f32-f64", test_scalar_f64_f32_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-f32-i8", test_scalar_f64_f32_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-f32-u8", test_scalar_f64_f32_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-f64-f32", test_scalar_f64_f64_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-f64-f64", test_scalar_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-f64-i8", test_scalar_f64_f64_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-f64-u8", test_scalar_f64_f64_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-i8-f32", test_scalar_f64_i8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-i8-f64", test_scalar_f64_i8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-i8-i8", test_scalar_f64_i8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-i8-u8", test_scalar_f64_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-u8-f32", test_scalar_f64_u8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-u8-f64", test_scalar_f64_u8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-u8-i8", test_scalar_f64_u8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-f64-u8-u8", test_scalar_f64_u8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-f32-f32", test_scalar_i8_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-f32-f64", test_scalar_i8_f32_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-f32-i8", test_scalar_i8_f32_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-f32-u8", test_scalar_i8_f32_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-f64-f32", test_scalar_i8_f64_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-f64-f64", test_scalar_i8_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-f64-i8", test_scalar_i8_f64_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-f64-u8", test_scalar_i8_f64_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-i8-f32", test_scalar_i8_i8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-i8-f64", test_scalar_i8_i8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-i8-i8", test_scalar_i8_i8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-i8-u8", test_scalar_i8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-u8-f32", test_scalar_i8_u8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-u8-f64", test_scalar_i8_u8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-u8-i8", test_scalar_i8_u8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-i8-u8-u8", test_scalar_i8_u8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-f32-f32", test_scalar_u8_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-f32-f64", test_scalar_u8_f32_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-f32-i8", test_scalar_u8_f32_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-f32-u8", test_scalar_u8_f32_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-f64-f32", test_scalar_u8_f64_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-f64-f64", test_scalar_u8_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-f64-i8", test_scalar_u8_f64_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-f64-u8", test_scalar_u8_f64_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-i8-f32", test_scalar_u8_i8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-i8-f64", test_scalar_u8_i8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-i8-i8", test_scalar_u8_i8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/scalar-u8-i8-u8", test_scalar_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-u8-f32", test_scalar_u8_u8_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-u8-f64", test_scalar_u8_u8_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-u8-i8", test_scalar_u8_u8_i8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {"/scalar-u8-u8-u8", test_scalar_u8_u8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, - {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL} -}; + {"/scalar-u8-i8-u8-medium", test_scalar_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, +#ifdef __AVX512VNNI__ + {"/avx512vnni-u8-i8-u8", test_avx512vnni_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/avx512vnni-u8-i8-u8-medium", test_avx512vnni_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, +#endif + {"/dispatched-u8-i8-u8", test_dispatched_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/dispatched-u8-i8-u8-medium", test_dispatched_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}}; -static const MunitSuite suite = { "/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE }; +static const MunitSuite suite = {"/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE}; int main(int argc, char *argv[MUNIT_ARRAY_PARAM(argc)]) { - return munit_suite_main(&suite, NULL, argc, argv); + return munit_suite_main(&suite, NULL, argc, argv); } diff --git a/test/test_matmul_simd.h b/test/test_matmul_simd.h @@ -41,6 +41,10 @@ #include <stddef.h> #include <stdint.h> +#ifdef __cplusplus +extern "C" { +#endif + #define MATMUL_SIMD_VERSION "1.0.0" #define MATMUL_FLAG_SCALAR (1 << 0) @@ -54,128 +58,16 @@ typedef uint32_t matmul_feature_t; matmul_feature_t matmul_get_feature(void); const char *matmul_get_feature_name(matmul_feature_t feat); -int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale); -int matmul_scalar_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale); -int matmul_scalar_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale); -int matmul_scalar_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale); -int matmul_scalar_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale); -int matmul_scalar_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale); -int matmul_scalar_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale); -int matmul_scalar_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale); -int matmul_scalar_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale); -int matmul_scalar_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale); -int matmul_scalar_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale); -int matmul_scalar_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale); -int matmul_scalar_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale); -int matmul_scalar_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale); -int matmul_scalar_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale); -int matmul_scalar_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale); -int matmul_scalar_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale); -int matmul_scalar_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale); -int matmul_scalar_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale); -int matmul_scalar_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale); -int matmul_scalar_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale); -int matmul_scalar_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale); -int matmul_scalar_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale); -int matmul_scalar_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C, double scale); -int matmul_scalar_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale); -int matmul_scalar_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale); -int matmul_scalar_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale); -int matmul_scalar_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale); -int matmul_scalar_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale); -int matmul_scalar_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C, double scale); -int matmul_scalar_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale); -int matmul_scalar_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C, double scale); -int matmul_scalar_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale); -int matmul_scalar_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C, double scale); -int matmul_scalar_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale); -int matmul_scalar_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale); -int matmul_scalar_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale); -int matmul_scalar_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale); -int matmul_scalar_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale); -int matmul_scalar_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale); -int matmul_scalar_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale); -int matmul_scalar_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale); -int matmul_scalar_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale); -int matmul_scalar_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale); -int matmul_scalar_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale); -int matmul_scalar_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale); -int matmul_scalar_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale); -int matmul_scalar_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale); -int matmul_scalar_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale); -int matmul_scalar_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale); -int matmul_scalar_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale); -int matmul_scalar_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale); -int matmul_scalar_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale); -int matmul_scalar_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C, double scale); -int matmul_scalar_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale); -int matmul_scalar_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C, double scale); -int matmul_scalar_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale); -int matmul_scalar_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale); -int matmul_scalar_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale); int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale); -int matmul_scalar_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale); -int matmul_scalar_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale); -int matmul_scalar_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale); -int matmul_scalar_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale); - -#ifdef __cplusplus -extern "C" { -#endif #ifdef __AVX2__ #include <immintrin.h> -int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale); -int matmul_avx2_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale); -int matmul_avx2_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale); -int matmul_avx2_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale); -int matmul_avx2_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale); -int matmul_avx2_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale); -int matmul_avx2_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale); -int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale); -int matmul_avx2_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale); -int matmul_avx2_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale); -int matmul_avx2_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale); -int matmul_avx2_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale); -int matmul_avx2_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale); -int matmul_avx2_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale); -int matmul_avx2_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale); -int matmul_avx2_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale); -int matmul_avx2_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale); -int matmul_avx2_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale); -int matmul_avx2_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale); int matmul_avx2_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale); -int matmul_avx2_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale); -int matmul_avx2_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale); -int matmul_avx2_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale); -int matmul_avx2_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale); #endif -#ifdef __AVX512F__ -int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale); -int matmul_avx512_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale); -int matmul_avx512_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale); -int matmul_avx512_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale); -int matmul_avx512_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale); -int matmul_avx512_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale); -int matmul_avx512_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale); -int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale); -int matmul_avx512_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale); -int matmul_avx512_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale); -int matmul_avx512_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale); -int matmul_avx512_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale); -int matmul_avx512_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale); -int matmul_avx512_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale); -int matmul_avx512_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale); -int matmul_avx512_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale); -int matmul_avx512_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale); -int matmul_avx512_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale); -int matmul_avx512_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale); -int matmul_avx512_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale); -int matmul_avx512_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale); -int matmul_avx512_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale); -int matmul_avx512_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale); -int matmul_avx512_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale); +#ifdef __AVX512VNNI__ +int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, + double scale); #endif #ifdef __cplusplus