matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

commit 5fca7db716571e8b1188f10d3cf38b0fa89f33b6
parent adb428832e725526e51335694a2917f4d0cc3091
Author: finwo <finwo@pm.me>
Date:   Thu, 16 Apr 2026 23:30:08 +0200

Re-add vnni accel

Diffstat:
M src/matmul.c | 2270 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------
M src/matmul.h | 166 ++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------
2 files changed, 1747 insertions(+), 689 deletions(-)

diff --git a/src/matmul.c b/src/matmul.c @@ -55,7 +55,13 @@ static void init_feature(void) { if (__builtin_cpu_supports("avx2")) g_feature |= MATMUL_FLAG_AVX2; #endif #ifdef __AVX512F__ - if (__builtin_cpu_supports("avx512f")) g_feature |= MATMUL_FLAG_AVX512; + if (__builtin_cpu_supports("avx512f")) { + g_feature |= MATMUL_FLAG_AVX512; + if (__builtin_cpu_supports("avx512vnni")) g_feature |= MATMUL_FLAG_AVX512_VNNI; + } +#endif +#ifdef __AVX2__ + if (__builtin_cpu_supports("avx2") && __builtin_cpu_supports("avxvnni")) g_feature |= MATMUL_FLAG_AVXVNNI; #endif } @@ -1433,952 +1439,1966 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con } #endif -int matmul_scalar_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) { +#ifdef __AVX2__ +static inline int32_t reduce_add_i32x8(__m256i v) { + __m128i low = _mm256_extracti128_si256(v, 0); + __m128i high = _mm256_extracti128_si256(v, 1); + __m128i sum = _mm_add_epi32(low, high); + sum = _mm_hadd_epi32(sum, sum); + sum = _mm_hadd_epi32(sum, sum); + return _mm_cvtsi128_si32(sum); +} + +int matmul_avx2_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) { #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif for (size_t i = 0; i < m; i++) { for (size_t j = 0; j < p; j++) { - int sum = 0; - for (size_t k = 0; k < n; k++) { - sum += (int)A[i * n + k] * (int)B[k * p + j]; + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = 
_mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); } - if (scale != 0 && scale != 1) sum = (int)(sum / scale); - if (sum > 255) - sum = 255; - else if (sum < 0) - sum = 0; - C[i * p + j] = (uint8_t)sum; + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + C[i * p + j] = (float)s; } } return 0; } -static int _matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const float *, float *, double) = NULL; - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - impl = matmul_avx512_f32_f32_f32; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - impl = matmul_avx2_f32_f32_f32; - else +int matmul_avx2_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) #endif - impl = matmul_scalar_f32_f32_f32; - initialized = 1; + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int32_t sum[8]; + 
_mm256_storeu_si256((__m256i *)sum, sum_vec); + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((double)s / scale); + C[i * p + j] = (double)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const float *, double *, double) = NULL; - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - impl = matmul_avx512_f32_f32_f64; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - impl = matmul_avx2_f32_f32_f64; - else +int matmul_avx2_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) #endif - impl = matmul_scalar_f32_f32_f64; - initialized = 1; + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int32_t sum[8]; + _mm256_storeu_si256((__m256i *)sum, sum_vec); + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / 
(float)scale); + if (s > 127) + s = 127; + else if (s < -128) + s = -128; + C[i * p + j] = (int8_t)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const float *, int8_t *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_f32_i8; - initialized = 1; +int matmul_avx2_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int32_t sum[8]; + _mm256_storeu_si256((__m256i *)sum, sum_vec); + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 255) + s = 255; + else if (s < 0) + s = 0; + C[i * p + j] = (uint8_t)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const float *, uint8_t *, double) = NULL; - 
static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_f32_u8; - initialized = 1; +int matmul_avx2_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + C[i * p + j] = (float)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const double *, float *, double) = NULL; - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - impl = matmul_avx512_f32_f64_f32; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - impl = matmul_avx2_f32_f64_f32; - else +int matmul_avx2_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) #endif - impl = 
matmul_scalar_f32_f64_f32; - initialized = 1; + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((double)s / scale); + C[i * p + j] = (double)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const double *, double *, double) = NULL; - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - impl = matmul_avx512_f32_f64_f64; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - impl = matmul_avx2_f32_f64_f64; - else +int matmul_avx2_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) #endif - impl = matmul_scalar_f32_f64_f64; - initialized = 1; + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const 
__m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 127) + s = 127; + else if (s < -128) + s = -128; + C[i * p + j] = (int8_t)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const double *, int8_t *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_f64_i8; - initialized = 1; +int matmul_avx2_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, 
mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 255) + s = 255; + else if (s < 0) + s = 0; + C[i * p + j] = (uint8_t)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const double *, uint8_t *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_f64_u8; - initialized = 1; +int matmul_avx2_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + C[i * p + j] = (float)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale) { - static int 
(*impl)(size_t, size_t, size_t, const float *, const int8_t *, float *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_i8_f32; - initialized = 1; +int matmul_avx2_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((double)s / scale); + C[i * p + j] = (double)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const int8_t *, double *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_i8_f64; - initialized = 1; +int matmul_avx2_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + 
for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 127) + s = 127; + else if (s < -128) + s = -128; + C[i * p + j] = (int8_t)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const int8_t *, int8_t *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_i8_i8; - initialized = 1; +int matmul_avx2_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + 
__m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 255) + s = 255; + else if (s < 0) + s = 0; + C[i * p + j] = (uint8_t)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const int8_t *, uint8_t *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_i8_u8; - initialized = 1; +int matmul_avx2_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + C[i * p + j] = (float)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_u8_f32(size_t m, size_t 
n, size_t p, const float *A, const uint8_t *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const uint8_t *, float *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_u8_f32; - initialized = 1; +int matmul_avx2_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((double)s / scale); + C[i * p + j] = (double)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const uint8_t *, double *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_u8_f64; - initialized = 1; +int matmul_avx2_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for 
(size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 127) + s = 127; + else if (s < -128) + s = -128; + C[i * p + j] = (int8_t)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const uint8_t *, int8_t *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_u8_i8; - initialized = 1; +int matmul_avx2_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m256i sum_vec = _mm256_setzero_si256(); + size_t k = 0; + for (; k + 31 < n; k += 32) { + __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m256i b_hi = 
_mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m256i mul_lo = _mm256_madd_epi16(a_lo, b_lo); + __m256i mul_hi = _mm256_madd_epi16(a_hi, b_hi); + sum_vec = _mm256_add_epi32(sum_vec, mul_lo); + sum_vec = _mm256_add_epi32(sum_vec, mul_hi); + } + int s = reduce_add_i32x8(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 255) + s = 255; + else if (s < 0) + s = 0; + C[i * p + j] = (uint8_t)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } +#endif -static int _matmul_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const float *, const uint8_t *, uint8_t *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_f32_u8_u8; - initialized = 1; - } - return impl(m, n, p, A, B, C, scale); +#ifdef __AVX512F__ +static inline int32_t reduce_add_i32x16(__m512i v) { + __m256i low = _mm512_extracti64x4_epi64(v, 0); + __m256i high = _mm512_extracti64x4_epi64(v, 1); + __m256i sum = _mm256_add_epi32(low, high); + sum = _mm256_hadd_epi32(sum, sum); + sum = _mm256_hadd_epi32(sum, sum); + __m128i s128 = _mm256_extracti128_si256(sum, 0); + s128 = _mm_hadd_epi32(s128, s128); + s128 = _mm_hadd_epi32(s128, s128); + return _mm_cvtsi128_si32(s128); } -static int _matmul_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const float *, float *, double) = NULL; - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - impl = matmul_avx512_f64_f32_f32; - else -#endif -#ifdef __AVX2__ - if (feat & MATMUL_FLAG_AVX2) - impl = matmul_avx2_f64_f32_f32; - else +int matmul_avx512_i8_i8_f32(size_t m, size_t n, size_t p, 
const int8_t *A, const int8_t *B, float *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) #endif - impl = matmul_scalar_f64_f32_f32; - initialized = 1; + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + C[i * p + j] = (float)s; + } } - return impl(m, n, p, A, B, C, scale); + return 0; } -static int _matmul_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const float *, double *, double) = NULL; - static int initialized = 0; - if (!initialized) { - matmul_feature_t feat = matmul_get_feature(); -#ifdef __AVX512F__ - if (feat & MATMUL_FLAG_AVX512) - impl = matmul_avx512_f64_f32_f64; - else +int matmul_avx512_i8_i8_f64(size_t m, size_t n, size_t 
p, const int8_t *A, const int8_t *B, double *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((double)s / scale); + C[i * p + j] = (double)s; + } + } + return 0; +} + +int matmul_avx512_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = 
_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 127) + s = 127; + else if (s < -128) + s = -128; + C[i * p + j] = (int8_t)s; + } + } + return 0; +} + +int matmul_avx512_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i 
b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 255) + s = 255; + else if (s < 0) + s = 0; + C[i * p + j] = (uint8_t)s; + } + } + return 0; +} + +int matmul_avx512_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / 
(float)scale); + C[i * p + j] = (float)s; + } + } + return 0; +} + +int matmul_avx512_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((double)s / scale); + C[i * p + j] = (double)s; + } + } + return 0; +} + +int matmul_avx512_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + 
__m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 127) + s = 127; + else if (s < -128) + s = -128; + C[i * p + j] = (int8_t)s; + } + } + return 0; +} + +int matmul_avx512_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 
16])); + __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 255) + s = 255; + else if (s < 0) + s = 0; + C[i * p + j] = (uint8_t)s; + } + } + return 0; +} + +int matmul_avx512_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; 
k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + C[i * p + j] = (float)s; + } + } + return 0; +} + +int matmul_avx512_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((double)s / scale); + C[i * p + j] = (double)s; + } + } + return 0; +} + +int matmul_avx512_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k 
+= 64) { + __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 127) + s = 127; + else if (s < -128) + s = -128; + C[i * p + j] = (int8_t)s; + } + } + return 0; +} + +int matmul_avx512_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * 
p + j])); + __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 255) + s = 255; + else if (s < 0) + s = 0; + C[i * p + j] = (uint8_t)s; + } + } + return 0; +} + +int matmul_avx512_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = 
_mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + C[i * p + j] = (float)s; + } + } + return 0; +} + +int matmul_avx512_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((double)s / scale); + C[i * p + j] = (double)s; + } + } + return 0; +} + +int matmul_avx512_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; 
j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 127) + s = 127; + else if (s < -128) + s = -128; + C[i * p + j] = (int8_t)s; + } + } + return 0; +} + +int matmul_avx512_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + __m512i sum_vec = _mm512_setzero_si512(); + size_t k = 0; + for (; k + 63 < n; k += 64) { + __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k])); + __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16])); + __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32])); + __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i 
*)&A[i * n + k + 48])); + __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j])); + __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 16])); + __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 32])); + __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&B[k * p + j + 48])); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2); + sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3); + } + int s = reduce_add_i32x16(sum_vec); + for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; + if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale); + if (s > 255) + s = 255; + else if (s < 0) + s = 0; + C[i * p + j] = (uint8_t)s; + } + } + return 0; +} +#endif + +int matmul_scalar_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) +#endif + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < p; j++) { + int sum = 0; + for (size_t k = 0; k < n; k++) { + sum += (int)A[i * n + k] * (int)B[k * p + j]; + } + if (scale != 0 && scale != 1) sum = (int)(sum / scale); + if (sum > 255) + sum = 255; + else if (sum < 0) + sum = 0; + C[i * p + j] = (uint8_t)sum; + } + } + return 0; +} + +static int _matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512) + matmul_f32_f32_f32 = matmul_avx512_f32_f32_f32; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVX2) + matmul_f32_f32_f32 = matmul_avx2_f32_f32_f32; + else +#endif + matmul_f32_f32_f32 = matmul_scalar_f32_f32_f32; + initialized = 1; + } + return matmul_f32_f32_f32(m, n, p, A, B, C, scale); +} + +static 
int _matmul_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512) + matmul_f32_f32_f64 = matmul_avx512_f32_f32_f64; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVX2) + matmul_f32_f32_f64 = matmul_avx2_f32_f32_f64; + else +#endif + matmul_f32_f32_f64 = matmul_scalar_f32_f32_f64; + initialized = 1; + } + return matmul_f32_f32_f64(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_f32_f32_i8 = matmul_scalar_f32_f32_i8; + initialized = 1; + } + return matmul_f32_f32_i8(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_f32_f32_u8 = matmul_scalar_f32_f32_u8; + initialized = 1; + } + return matmul_f32_f32_u8(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512) + matmul_f32_f64_f32 = matmul_avx512_f32_f64_f32; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVX2) + matmul_f32_f64_f32 = matmul_avx2_f32_f64_f32; + else +#endif + matmul_f32_f64_f32 = matmul_scalar_f32_f64_f32; + initialized = 1; + } + return matmul_f32_f64_f32(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_feature_t feat = matmul_get_feature(); +#ifdef 
__AVX512F__ + if (feat & MATMUL_FLAG_AVX512) + matmul_f32_f64_f64 = matmul_avx512_f32_f64_f64; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVX2) + matmul_f32_f64_f64 = matmul_avx2_f32_f64_f64; + else +#endif + matmul_f32_f64_f64 = matmul_scalar_f32_f64_f64; + initialized = 1; + } + return matmul_f32_f64_f64(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_f32_f64_i8 = matmul_scalar_f32_f64_i8; + initialized = 1; + } + return matmul_f32_f64_i8(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_f32_f64_u8 = matmul_scalar_f32_f64_u8; + initialized = 1; + } + return matmul_f32_f64_u8(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_f32_i8_f32 = matmul_scalar_f32_i8_f32; + initialized = 1; + } + return matmul_f32_i8_f32(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_f32_i8_f64 = matmul_scalar_f32_i8_f64; + initialized = 1; + } + return matmul_f32_i8_f64(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_f32_i8_i8 = matmul_scalar_f32_i8_i8; + initialized = 1; + } + return matmul_f32_i8_i8(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale) { + static int 
initialized = 0; + if (!initialized) { + matmul_f32_i8_u8 = matmul_scalar_f32_i8_u8; + initialized = 1; + } + return matmul_f32_i8_u8(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_f32_u8_f32 = matmul_scalar_f32_u8_f32; + initialized = 1; + } + return matmul_f32_u8_f32(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_f32_u8_f64 = matmul_scalar_f32_u8_f64; + initialized = 1; + } + return matmul_f32_u8_f64(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_f32_u8_i8 = matmul_scalar_f32_u8_i8; + initialized = 1; + } + return matmul_f32_u8_i8(m, n, p, A, B, C, scale); +} + +static int _matmul_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_f32_u8_u8 = matmul_scalar_f32_u8_u8; + initialized = 1; + } + return matmul_f32_u8_u8(m, n, p, A, B, C, scale); +} + +static int _matmul_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512) + matmul_f64_f32_f32 = matmul_avx512_f64_f32_f32; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVX2) + matmul_f64_f32_f32 = matmul_avx2_f64_f32_f32; + else +#endif + matmul_f64_f32_f32 = matmul_scalar_f64_f32_f32; + initialized = 1; + } + return matmul_f64_f32_f32(m, n, p, A, B, C, scale); +} + +static int _matmul_f64_f32_f64(size_t m, size_t 
n, size_t p, const double *A, const float *B, double *C, double scale) { + static int initialized = 0; + if (!initialized) { + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512) + matmul_f64_f32_f64 = matmul_avx512_f64_f32_f64; + else #endif #ifdef __AVX2__ if (feat & MATMUL_FLAG_AVX2) - impl = matmul_avx2_f64_f32_f64; + matmul_f64_f32_f64 = matmul_avx2_f64_f32_f64; else #endif - impl = matmul_scalar_f64_f32_f64; + matmul_f64_f32_f64 = matmul_scalar_f64_f32_f64; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_f32_f64(m, n, p, A, B, C, scale); } static int _matmul_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const float *, int8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_f32_i8; - initialized = 1; + matmul_f64_f32_i8 = matmul_scalar_f64_f32_i8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_f32_i8(m, n, p, A, B, C, scale); } static int _matmul_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const float *, uint8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_f32_u8; - initialized = 1; + matmul_f64_f32_u8 = matmul_scalar_f64_f32_u8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_f32_u8(m, n, p, A, B, C, scale); } static int _matmul_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const double *, float *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { matmul_feature_t feat = 
matmul_get_feature(); #ifdef __AVX512F__ if (feat & MATMUL_FLAG_AVX512) - impl = matmul_avx512_f64_f64_f32; + matmul_f64_f64_f32 = matmul_avx512_f64_f64_f32; else #endif #ifdef __AVX2__ if (feat & MATMUL_FLAG_AVX2) - impl = matmul_avx2_f64_f64_f32; + matmul_f64_f64_f32 = matmul_avx2_f64_f64_f32; else #endif - impl = matmul_scalar_f64_f64_f32; + matmul_f64_f64_f32 = matmul_scalar_f64_f64_f32; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_f64_f32(m, n, p, A, B, C, scale); } static int _matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const double *, double *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { matmul_feature_t feat = matmul_get_feature(); #ifdef __AVX512F__ if (feat & MATMUL_FLAG_AVX512) - impl = matmul_avx512_f64_f64_f64; + matmul_f64_f64_f64 = matmul_avx512_f64_f64_f64; else #endif #ifdef __AVX2__ if (feat & MATMUL_FLAG_AVX2) - impl = matmul_avx2_f64_f64_f64; + matmul_f64_f64_f64 = matmul_avx2_f64_f64_f64; else #endif - impl = matmul_scalar_f64_f64_f64; + matmul_f64_f64_f64 = matmul_scalar_f64_f64_f64; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_f64_f64(m, n, p, A, B, C, scale); } static int _matmul_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const double *, int8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_f64_i8; - initialized = 1; + matmul_f64_f64_i8 = matmul_scalar_f64_f64_i8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_f64_i8(m, n, p, A, B, C, scale); } static int _matmul_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C, double scale) { - 
static int (*impl)(size_t, size_t, size_t, const double *, const double *, uint8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_f64_u8; - initialized = 1; + matmul_f64_f64_u8 = matmul_scalar_f64_f64_u8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_f64_u8(m, n, p, A, B, C, scale); } static int _matmul_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const int8_t *, float *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_i8_f32; - initialized = 1; + matmul_f64_i8_f32 = matmul_scalar_f64_i8_f32; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_i8_f32(m, n, p, A, B, C, scale); } static int _matmul_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const int8_t *, double *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_i8_f64; - initialized = 1; + matmul_f64_i8_f64 = matmul_scalar_f64_i8_f64; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_i8_f64(m, n, p, A, B, C, scale); } static int _matmul_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const int8_t *, int8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_i8_i8; - initialized = 1; + matmul_f64_i8_i8 = matmul_scalar_f64_i8_i8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_i8_i8(m, n, p, A, B, C, scale); } static int _matmul_f64_i8_u8(size_t m, 
size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const int8_t *, uint8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_i8_u8; - initialized = 1; + matmul_f64_i8_u8 = matmul_scalar_f64_i8_u8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_i8_u8(m, n, p, A, B, C, scale); } static int _matmul_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const uint8_t *, float *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_u8_f32; - initialized = 1; + matmul_f64_u8_f32 = matmul_scalar_f64_u8_f32; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_u8_f32(m, n, p, A, B, C, scale); } static int _matmul_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const uint8_t *, double *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_u8_f64; - initialized = 1; + matmul_f64_u8_f64 = matmul_scalar_f64_u8_f64; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_u8_f64(m, n, p, A, B, C, scale); } static int _matmul_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const uint8_t *, int8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_u8_i8; - initialized = 1; + matmul_f64_u8_i8 = matmul_scalar_f64_u8_i8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return 
matmul_f64_u8_i8(m, n, p, A, B, C, scale); } static int _matmul_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const double *, const uint8_t *, uint8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_f64_u8_u8; - initialized = 1; + matmul_f64_u8_u8 = matmul_scalar_f64_u8_u8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_f64_u8_u8(m, n, p, A, B, C, scale); } static int _matmul_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const float *, float *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_f32_f32; - initialized = 1; + matmul_i8_f32_f32 = matmul_scalar_i8_f32_f32; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_f32_f32(m, n, p, A, B, C, scale); } static int _matmul_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const float *, double *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_f32_f64; - initialized = 1; + matmul_i8_f32_f64 = matmul_scalar_i8_f32_f64; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_f32_f64(m, n, p, A, B, C, scale); } static int _matmul_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const float *, int8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_f32_i8; - initialized = 1; + matmul_i8_f32_i8 = 
matmul_scalar_i8_f32_i8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_f32_i8(m, n, p, A, B, C, scale); } static int _matmul_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const float *, uint8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_f32_u8; - initialized = 1; + matmul_i8_f32_u8 = matmul_scalar_i8_f32_u8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_f32_u8(m, n, p, A, B, C, scale); } static int _matmul_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const double *, float *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_f64_f32; - initialized = 1; + matmul_i8_f64_f32 = matmul_scalar_i8_f64_f32; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_f64_f32(m, n, p, A, B, C, scale); } static int _matmul_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const double *, double *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_f64_f64; - initialized = 1; + matmul_i8_f64_f64 = matmul_scalar_i8_f64_f64; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_f64_f64(m, n, p, A, B, C, scale); } static int _matmul_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const double *, int8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) 
{ - impl = matmul_scalar_i8_f64_i8; - initialized = 1; + matmul_i8_f64_i8 = matmul_scalar_i8_f64_i8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_f64_i8(m, n, p, A, B, C, scale); } static int _matmul_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const double *, uint8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_f64_u8; - initialized = 1; + matmul_i8_f64_u8 = matmul_scalar_i8_f64_u8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_f64_u8(m, n, p, A, B, C, scale); } static int _matmul_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const int8_t *, float *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_i8_f32; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_i8_i8_f32 = matmul_avx512_i8_i8_f32; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_i8_i8_f32 = matmul_avx2_i8_i8_f32; + else +#endif + matmul_i8_i8_f32 = matmul_scalar_i8_i8_f32; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_i8_f32(m, n, p, A, B, C, scale); } static int _matmul_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const int8_t *, double *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_i8_f64; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_i8_i8_f64 = matmul_avx512_i8_i8_f64; + 
else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_i8_i8_f64 = matmul_avx2_i8_i8_f64; + else +#endif + matmul_i8_i8_f64 = matmul_scalar_i8_i8_f64; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_i8_f64(m, n, p, A, B, C, scale); } static int _matmul_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const int8_t *, int8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_i8_i8; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_i8_i8_i8 = matmul_avx512_i8_i8_i8; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_i8_i8_i8 = matmul_avx2_i8_i8_i8; + else +#endif + matmul_i8_i8_i8 = matmul_scalar_i8_i8_i8; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_i8_i8(m, n, p, A, B, C, scale); } static int _matmul_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const int8_t *, uint8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_i8_u8; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_i8_i8_u8 = matmul_avx512_i8_i8_u8; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_i8_i8_u8 = matmul_avx2_i8_i8_u8; + else +#endif + matmul_i8_i8_u8 = matmul_scalar_i8_i8_u8; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_i8_u8(m, n, p, A, B, C, scale); } static int _matmul_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const 
uint8_t *, float *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_u8_f32; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_i8_u8_f32 = matmul_avx512_i8_u8_f32; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_i8_u8_f32 = matmul_avx2_i8_u8_f32; + else +#endif + matmul_i8_u8_f32 = matmul_scalar_i8_u8_f32; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_u8_f32(m, n, p, A, B, C, scale); } static int _matmul_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const uint8_t *, double *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_u8_f64; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_i8_u8_f64 = matmul_avx512_i8_u8_f64; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_i8_u8_f64 = matmul_avx2_i8_u8_f64; + else +#endif + matmul_i8_u8_f64 = matmul_scalar_i8_u8_f64; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_u8_f64(m, n, p, A, B, C, scale); } static int _matmul_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const uint8_t *, int8_t *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_i8_u8_i8; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_i8_u8_i8 = matmul_avx512_i8_u8_i8; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_i8_u8_i8 = matmul_avx2_i8_u8_i8; + else +#endif + matmul_i8_u8_i8 = 
matmul_scalar_i8_u8_i8; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_u8_i8(m, n, p, A, B, C, scale); } static int _matmul_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const int8_t *, const uint8_t *, uint8_t *, double) = NULL; - static int initialized = 0; - if (!initialized) { - impl = matmul_scalar_i8_u8_u8; - initialized = 1; - } - return impl(m, n, p, A, B, C, scale); -} - -static int _matmul_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const float *, float *, double) = NULL; - static int initialized = 0; + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_f32_f32; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_i8_u8_u8 = matmul_avx512_i8_u8_u8; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_i8_u8_u8 = matmul_avx2_i8_u8_u8; + else +#endif + matmul_i8_u8_u8 = matmul_scalar_i8_u8_u8; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_i8_u8_u8(m, n, p, A, B, C, scale); } -static int _matmul_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const float *, double *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_f32_f64; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_u8_i8_f32 = matmul_avx512_u8_i8_f32; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_u8_i8_f32 = 
matmul_avx2_u8_i8_f32; + else +#endif + matmul_u8_i8_f32 = matmul_scalar_u8_i8_f32; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_i8_f32(m, n, p, A, B, C, scale); } -static int _matmul_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const float *, int8_t *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_f32_i8; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_u8_i8_f64 = matmul_avx512_u8_i8_f64; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_u8_i8_f64 = matmul_avx2_u8_i8_f64; + else +#endif + matmul_u8_i8_f64 = matmul_scalar_u8_i8_f64; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_i8_f64(m, n, p, A, B, C, scale); } -static int _matmul_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const float *, uint8_t *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_f32_u8; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_u8_i8_i8 = matmul_avx512_u8_i8_i8; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_u8_i8_i8 = matmul_avx2_u8_i8_i8; + else +#endif + matmul_u8_i8_i8 = matmul_scalar_u8_i8_i8; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_i8_i8(m, n, p, A, B, C, scale); } -static int 
_matmul_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const double *, float *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_f64_f32; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_u8_i8_u8 = matmul_avx512_u8_i8_u8; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_u8_i8_u8 = matmul_avx2_u8_i8_u8; + else +#endif + matmul_u8_i8_u8 = matmul_scalar_u8_i8_u8; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_i8_u8(m, n, p, A, B, C, scale); } -static int _matmul_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C, - double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const double *, double *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_f64_f64; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_u8_u8_f32 = matmul_avx512_u8_u8_f32; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_u8_u8_f32 = matmul_avx2_u8_u8_f32; + else +#endif + matmul_u8_u8_f32 = matmul_scalar_u8_u8_f32; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_u8_f32(m, n, p, A, B, C, scale); } -static int _matmul_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const double *, int8_t *, double) 
= NULL; - static int initialized = 0; +static int _matmul_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, + double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_f64_i8; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_u8_u8_f64 = matmul_avx512_u8_u8_f64; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_u8_u8_f64 = matmul_avx2_u8_u8_f64; + else +#endif + matmul_u8_u8_f64 = matmul_scalar_u8_u8_f64; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_u8_f64(m, n, p, A, B, C, scale); } -static int _matmul_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C, - double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const double *, uint8_t *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_f64_u8; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_u8_u8_i8 = matmul_avx512_u8_u8_i8; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_u8_u8_i8 = matmul_avx2_u8_u8_i8; + else +#endif + matmul_u8_u8_i8 = matmul_scalar_u8_u8_i8; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_u8_i8(m, n, p, A, B, C, scale); } -static int _matmul_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const int8_t *, float *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, + double scale) { + static int initialized = 0; if 
(!initialized) { - impl = matmul_scalar_u8_i8_f32; + matmul_feature_t feat = matmul_get_feature(); +#ifdef __AVX512F__ + if (feat & MATMUL_FLAG_AVX512_VNNI) + matmul_u8_u8_u8 = matmul_avx512_u8_u8_u8; + else +#endif +#ifdef __AVX2__ + if (feat & MATMUL_FLAG_AVXVNNI) + matmul_u8_u8_u8 = matmul_avx2_u8_u8_u8; + else +#endif + matmul_u8_u8_u8 = matmul_scalar_u8_u8_u8; initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_u8_u8(m, n, p, A, B, C, scale); } -static int _matmul_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const int8_t *, double *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_i8_f64; - initialized = 1; + matmul_u8_f32_f32 = matmul_scalar_u8_f32_f32; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_f32_f32(m, n, p, A, B, C, scale); } -static int _matmul_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const int8_t *, int8_t *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_i8_i8; - initialized = 1; + matmul_u8_f32_f64 = matmul_scalar_u8_f32_f64; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_f32_f64(m, n, p, A, B, C, scale); } -static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = NULL; - 
static int initialized = 0; +static int _matmul_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_i8_u8; - initialized = 1; + matmul_u8_f32_i8 = matmul_scalar_u8_f32_i8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_f32_i8(m, n, p, A, B, C, scale); } -static int _matmul_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, float *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_u8_f32; - initialized = 1; + matmul_u8_f32_u8 = matmul_scalar_u8_f32_u8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_f32_u8(m, n, p, A, B, C, scale); } -static int _matmul_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, - double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, double *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_u8_f64; - initialized = 1; + matmul_u8_f64_f32 = matmul_scalar_u8_f64_f32; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_f64_f32(m, n, p, A, B, C, scale); } -static int _matmul_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, int8_t *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_f64_f64(size_t 
m, size_t n, size_t p, const uint8_t *A, const double *B, double *C, + double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_u8_i8; - initialized = 1; + matmul_u8_f64_f64 = matmul_scalar_u8_f64_f64; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_f64_f64(m, n, p, A, B, C, scale); } -static int _matmul_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, - double scale) { - static int (*impl)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, double) = NULL; - static int initialized = 0; +static int _matmul_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale) { + static int initialized = 0; if (!initialized) { - impl = matmul_scalar_u8_u8_u8; - initialized = 1; + matmul_u8_f64_i8 = matmul_scalar_u8_f64_i8; + initialized = 1; } - return impl(m, n, p, A, B, C, scale); + return matmul_u8_f64_i8(m, n, p, A, B, C, scale); } -int matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { - return _matmul_f32_f32_f32(m, n, p, A, B, C, scale); -} -int matmul_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) { - return _matmul_f32_f32_f64(m, n, p, A, B, C, scale); -} -int matmul_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale) { - return _matmul_f32_f32_i8(m, n, p, A, B, C, scale); -} -int matmul_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale) { - return _matmul_f32_f32_u8(m, n, p, A, B, C, scale); -} -int matmul_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) { - return _matmul_f32_f64_f32(m, n, p, A, B, C, scale); -} -int matmul_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) { - return _matmul_f32_f64_f64(m, n, 
p, A, B, C, scale); -} -int matmul_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale) { - return _matmul_f32_f64_i8(m, n, p, A, B, C, scale); -} -int matmul_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale) { - return _matmul_f32_f64_u8(m, n, p, A, B, C, scale); -} -int matmul_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale) { - return _matmul_f32_i8_f32(m, n, p, A, B, C, scale); -} -int matmul_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale) { - return _matmul_f32_i8_f64(m, n, p, A, B, C, scale); -} -int matmul_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale) { - return _matmul_f32_i8_i8(m, n, p, A, B, C, scale); -} -int matmul_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale) { - return _matmul_f32_i8_u8(m, n, p, A, B, C, scale); -} -int matmul_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale) { - return _matmul_f32_u8_f32(m, n, p, A, B, C, scale); -} -int matmul_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale) { - return _matmul_f32_u8_f64(m, n, p, A, B, C, scale); -} -int matmul_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale) { - return _matmul_f32_u8_i8(m, n, p, A, B, C, scale); -} -int matmul_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale) { - return _matmul_f32_u8_u8(m, n, p, A, B, C, scale); -} -int matmul_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) { - return _matmul_f64_f32_f32(m, n, p, A, B, C, scale); -} -int matmul_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double 
scale) { - return _matmul_f64_f32_f64(m, n, p, A, B, C, scale); -} -int matmul_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale) { - return _matmul_f64_f32_i8(m, n, p, A, B, C, scale); -} -int matmul_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale) { - return _matmul_f64_f32_u8(m, n, p, A, B, C, scale); -} -int matmul_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) { - return _matmul_f64_f64_f32(m, n, p, A, B, C, scale); -} -int matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { - return _matmul_f64_f64_f64(m, n, p, A, B, C, scale); -} -int matmul_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale) { - return _matmul_f64_f64_i8(m, n, p, A, B, C, scale); -} -int matmul_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C, double scale) { - return _matmul_f64_f64_u8(m, n, p, A, B, C, scale); -} -int matmul_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale) { - return _matmul_f64_i8_f32(m, n, p, A, B, C, scale); -} -int matmul_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale) { - return _matmul_f64_i8_f64(m, n, p, A, B, C, scale); -} -int matmul_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale) { - return _matmul_f64_i8_i8(m, n, p, A, B, C, scale); -} -int matmul_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale) { - return _matmul_f64_i8_u8(m, n, p, A, B, C, scale); -} -int matmul_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale) { - return _matmul_f64_u8_f32(m, n, p, A, B, C, scale); -} -int matmul_f64_u8_f64(size_t m, size_t n, size_t p, 
const double *A, const uint8_t *B, double *C, double scale) { - return _matmul_f64_u8_f64(m, n, p, A, B, C, scale); -} -int matmul_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale) { - return _matmul_f64_u8_i8(m, n, p, A, B, C, scale); -} -int matmul_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C, double scale) { - return _matmul_f64_u8_u8(m, n, p, A, B, C, scale); -} -int matmul_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale) { - return _matmul_i8_f32_f32(m, n, p, A, B, C, scale); -} -int matmul_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C, double scale) { - return _matmul_i8_f32_f64(m, n, p, A, B, C, scale); -} -int matmul_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale) { - return _matmul_i8_f32_i8(m, n, p, A, B, C, scale); -} -int matmul_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale) { - return _matmul_i8_f32_u8(m, n, p, A, B, C, scale); -} -int matmul_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale) { - return _matmul_i8_f64_f32(m, n, p, A, B, C, scale); -} -int matmul_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale) { - return _matmul_i8_f64_f64(m, n, p, A, B, C, scale); -} -int matmul_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale) { - return _matmul_i8_f64_i8(m, n, p, A, B, C, scale); -} -int matmul_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale) { - return _matmul_i8_f64_u8(m, n, p, A, B, C, scale); -} -int matmul_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) { - return _matmul_i8_i8_f32(m, n, p, A, B, C, scale); -} -int 
matmul_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) { - return _matmul_i8_i8_f64(m, n, p, A, B, C, scale); -} -int matmul_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) { - return _matmul_i8_i8_i8(m, n, p, A, B, C, scale); -} -int matmul_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) { - return _matmul_i8_i8_u8(m, n, p, A, B, C, scale); -} -int matmul_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) { - return _matmul_i8_u8_f32(m, n, p, A, B, C, scale); -} -int matmul_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) { - return _matmul_i8_u8_f64(m, n, p, A, B, C, scale); -} -int matmul_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) { - return _matmul_i8_u8_i8(m, n, p, A, B, C, scale); -} -int matmul_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) { - return _matmul_i8_u8_u8(m, n, p, A, B, C, scale); -} -int matmul_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale) { - return _matmul_u8_f32_f32(m, n, p, A, B, C, scale); -} -int matmul_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale) { - return _matmul_u8_f32_f64(m, n, p, A, B, C, scale); -} -int matmul_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale) { - return _matmul_u8_f32_i8(m, n, p, A, B, C, scale); -} -int matmul_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale) { - return _matmul_u8_f32_u8(m, n, p, A, B, C, scale); -} -int matmul_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale) { - return _matmul_u8_f64_f32(m, 
n, p, A, B, C, scale); -} -int matmul_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C, double scale) { - return _matmul_u8_f64_f64(m, n, p, A, B, C, scale); -} -int matmul_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale) { - return _matmul_u8_f64_i8(m, n, p, A, B, C, scale); -} -int matmul_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C, double scale) { - return _matmul_u8_f64_u8(m, n, p, A, B, C, scale); -} -int matmul_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) { - return _matmul_u8_i8_f32(m, n, p, A, B, C, scale); -} -int matmul_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) { - return _matmul_u8_i8_f64(m, n, p, A, B, C, scale); -} -int matmul_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) { - return _matmul_u8_i8_i8(m, n, p, A, B, C, scale); -} -int matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { - return _matmul_u8_i8_u8(m, n, p, A, B, C, scale); -} -int matmul_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) { - return _matmul_u8_u8_f32(m, n, p, A, B, C, scale); -} -int matmul_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) { - return _matmul_u8_u8_f64(m, n, p, A, B, C, scale); -} -int matmul_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) { - return _matmul_u8_u8_i8(m, n, p, A, B, C, scale); -} -int matmul_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) { - return _matmul_u8_u8_u8(m, n, p, A, B, C, scale); -} +static int _matmul_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t 
*C, + double scale) { + static int initialized = 0; + if (!initialized) { + matmul_u8_f64_u8 = matmul_scalar_u8_f64_u8; + initialized = 1; + } + return matmul_u8_f64_u8(m, n, p, A, B, C, scale); +} + +int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double) = _matmul_f32_f32_f32; +int (*matmul_f32_f32_f64)(size_t, size_t, size_t, const float *, const float *, double *, double) = _matmul_f32_f32_f64; +int (*matmul_f32_f32_i8)(size_t, size_t, size_t, const float *, const float *, int8_t *, double) = _matmul_f32_f32_i8; +int (*matmul_f32_f32_u8)(size_t, size_t, size_t, const float *, const float *, uint8_t *, double) = _matmul_f32_f32_u8; +int (*matmul_f32_f64_f32)(size_t, size_t, size_t, const float *, const double *, float *, double) = _matmul_f32_f64_f32; +int (*matmul_f32_f64_f64)(size_t, size_t, size_t, const float *, const double *, double *, + double) = _matmul_f32_f64_f64; +int (*matmul_f32_f64_i8)(size_t, size_t, size_t, const float *, const double *, int8_t *, double) = _matmul_f32_f64_i8; +int (*matmul_f32_f64_u8)(size_t, size_t, size_t, const float *, const double *, uint8_t *, double) = _matmul_f32_f64_u8; +int (*matmul_f32_i8_f32)(size_t, size_t, size_t, const float *, const int8_t *, float *, double) = _matmul_f32_i8_f32; +int (*matmul_f32_i8_f64)(size_t, size_t, size_t, const float *, const int8_t *, double *, double) = _matmul_f32_i8_f64; +int (*matmul_f32_i8_i8)(size_t, size_t, size_t, const float *, const int8_t *, int8_t *, double) = _matmul_f32_i8_i8; +int (*matmul_f32_i8_u8)(size_t, size_t, size_t, const float *, const int8_t *, uint8_t *, double) = _matmul_f32_i8_u8; +int (*matmul_f32_u8_f32)(size_t, size_t, size_t, const float *, const uint8_t *, float *, double) = _matmul_f32_u8_f32; +int (*matmul_f32_u8_f64)(size_t, size_t, size_t, const float *, const uint8_t *, double *, double) = _matmul_f32_u8_f64; +int (*matmul_f32_u8_i8)(size_t, size_t, size_t, const float *, const uint8_t *, int8_t *, double) = 
_matmul_f32_u8_i8; +int (*matmul_f32_u8_u8)(size_t, size_t, size_t, const float *, const uint8_t *, uint8_t *, double) = _matmul_f32_u8_u8; +int (*matmul_f64_f32_f32)(size_t, size_t, size_t, const double *, const float *, float *, double) = _matmul_f64_f32_f32; +int (*matmul_f64_f32_f64)(size_t, size_t, size_t, const double *, const float *, double *, + double) = _matmul_f64_f32_f64; +int (*matmul_f64_f32_i8)(size_t, size_t, size_t, const double *, const float *, int8_t *, double) = _matmul_f64_f32_i8; +int (*matmul_f64_f32_u8)(size_t, size_t, size_t, const double *, const float *, uint8_t *, double) = _matmul_f64_f32_u8; +int (*matmul_f64_f64_f32)(size_t, size_t, size_t, const double *, const double *, float *, + double) = _matmul_f64_f64_f32; +int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, + double) = _matmul_f64_f64_f64; +int (*matmul_f64_f64_i8)(size_t, size_t, size_t, const double *, const double *, int8_t *, double) = _matmul_f64_f64_i8; +int (*matmul_f64_f64_u8)(size_t, size_t, size_t, const double *, const double *, uint8_t *, + double) = _matmul_f64_f64_u8; +int (*matmul_f64_i8_f32)(size_t, size_t, size_t, const double *, const int8_t *, float *, double) = _matmul_f64_i8_f32; +int (*matmul_f64_i8_f64)(size_t, size_t, size_t, const double *, const int8_t *, double *, double) = _matmul_f64_i8_f64; +int (*matmul_f64_i8_i8)(size_t, size_t, size_t, const double *, const int8_t *, int8_t *, double) = _matmul_f64_i8_i8; +int (*matmul_f64_i8_u8)(size_t, size_t, size_t, const double *, const int8_t *, uint8_t *, double) = _matmul_f64_i8_u8; +int (*matmul_f64_u8_f32)(size_t, size_t, size_t, const double *, const uint8_t *, float *, double) = _matmul_f64_u8_f32; +int (*matmul_f64_u8_f64)(size_t, size_t, size_t, const double *, const uint8_t *, double *, + double) = _matmul_f64_u8_f64; +int (*matmul_f64_u8_i8)(size_t, size_t, size_t, const double *, const uint8_t *, int8_t *, double) = _matmul_f64_u8_i8; +int 
(*matmul_f64_u8_u8)(size_t, size_t, size_t, const double *, const uint8_t *, uint8_t *, double) = _matmul_f64_u8_u8; +int (*matmul_i8_f32_f32)(size_t, size_t, size_t, const int8_t *, const float *, float *, double) = _matmul_i8_f32_f32; +int (*matmul_i8_f32_f64)(size_t, size_t, size_t, const int8_t *, const float *, double *, double) = _matmul_i8_f32_f64; +int (*matmul_i8_f32_i8)(size_t, size_t, size_t, const int8_t *, const float *, int8_t *, double) = _matmul_i8_f32_i8; +int (*matmul_i8_f32_u8)(size_t, size_t, size_t, const int8_t *, const float *, uint8_t *, double) = _matmul_i8_f32_u8; +int (*matmul_i8_f64_f32)(size_t, size_t, size_t, const int8_t *, const double *, float *, double) = _matmul_i8_f64_f32; +int (*matmul_i8_f64_f64)(size_t, size_t, size_t, const int8_t *, const double *, double *, double) = _matmul_i8_f64_f64; +int (*matmul_i8_f64_i8)(size_t, size_t, size_t, const int8_t *, const double *, int8_t *, double) = _matmul_i8_f64_i8; +int (*matmul_i8_f64_u8)(size_t, size_t, size_t, const int8_t *, const double *, uint8_t *, double) = _matmul_i8_f64_u8; +int (*matmul_i8_i8_f32)(size_t, size_t, size_t, const int8_t *, const int8_t *, float *, double) = _matmul_i8_i8_f32; +int (*matmul_i8_i8_f64)(size_t, size_t, size_t, const int8_t *, const int8_t *, double *, double) = _matmul_i8_i8_f64; +int (*matmul_i8_i8_i8)(size_t, size_t, size_t, const int8_t *, const int8_t *, int8_t *, double) = _matmul_i8_i8_i8; +int (*matmul_i8_i8_u8)(size_t, size_t, size_t, const int8_t *, const int8_t *, uint8_t *, double) = _matmul_i8_i8_u8; +int (*matmul_i8_u8_f32)(size_t, size_t, size_t, const int8_t *, const uint8_t *, float *, double) = _matmul_i8_u8_f32; +int (*matmul_i8_u8_f64)(size_t, size_t, size_t, const int8_t *, const uint8_t *, double *, double) = _matmul_i8_u8_f64; +int (*matmul_i8_u8_i8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, int8_t *, double) = _matmul_i8_u8_i8; +int (*matmul_i8_u8_u8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, 
uint8_t *, double) = _matmul_i8_u8_u8; +int (*matmul_u8_f32_f32)(size_t, size_t, size_t, const uint8_t *, const float *, float *, double) = _matmul_u8_f32_f32; +int (*matmul_u8_f32_f64)(size_t, size_t, size_t, const uint8_t *, const float *, double *, double) = _matmul_u8_f32_f64; +int (*matmul_u8_f32_i8)(size_t, size_t, size_t, const uint8_t *, const float *, int8_t *, double) = _matmul_u8_f32_i8; +int (*matmul_u8_f32_u8)(size_t, size_t, size_t, const uint8_t *, const float *, uint8_t *, double) = _matmul_u8_f32_u8; +int (*matmul_u8_f64_f32)(size_t, size_t, size_t, const uint8_t *, const double *, float *, double) = _matmul_u8_f64_f32; +int (*matmul_u8_f64_f64)(size_t, size_t, size_t, const uint8_t *, const double *, double *, + double) = _matmul_u8_f64_f64; +int (*matmul_u8_f64_i8)(size_t, size_t, size_t, const uint8_t *, const double *, int8_t *, double) = _matmul_u8_f64_i8; +int (*matmul_u8_f64_u8)(size_t, size_t, size_t, const uint8_t *, const double *, uint8_t *, double) = _matmul_u8_f64_u8; +int (*matmul_u8_i8_f32)(size_t, size_t, size_t, const uint8_t *, const int8_t *, float *, double) = _matmul_u8_i8_f32; +int (*matmul_u8_i8_f64)(size_t, size_t, size_t, const uint8_t *, const int8_t *, double *, double) = _matmul_u8_i8_f64; +int (*matmul_u8_i8_i8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, int8_t *, double) = _matmul_u8_i8_i8; +int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = _matmul_u8_i8_u8; +int (*matmul_u8_u8_f32)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, float *, double) = _matmul_u8_u8_f32; +int (*matmul_u8_u8_f64)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, double *, double) = _matmul_u8_u8_f64; +int (*matmul_u8_u8_i8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, int8_t *, double) = _matmul_u8_u8_i8; +int (*matmul_u8_u8_u8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, double) = _matmul_u8_u8_u8; diff --git 
a/src/matmul.h b/src/matmul.h @@ -56,70 +56,70 @@ typedef uint32_t matmul_feature_t; matmul_feature_t matmul_get_feature(void); const char *matmul_get_feature_name(matmul_feature_t feat); -int matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale); -int matmul_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale); -int matmul_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale); -int matmul_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale); -int matmul_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale); -int matmul_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale); -int matmul_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale); -int matmul_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale); -int matmul_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale); -int matmul_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale); -int matmul_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale); -int matmul_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale); -int matmul_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale); -int matmul_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale); -int matmul_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale); -int matmul_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale); -int 
matmul_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale); -int matmul_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale); -int matmul_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale); -int matmul_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale); -int matmul_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale); -int matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale); -int matmul_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale); -int matmul_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C, double scale); -int matmul_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale); -int matmul_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale); -int matmul_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale); -int matmul_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale); -int matmul_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale); -int matmul_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C, double scale); -int matmul_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale); -int matmul_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C, double scale); -int matmul_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale); -int matmul_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, 
double *C, double scale); -int matmul_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale); -int matmul_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale); -int matmul_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale); -int matmul_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale); -int matmul_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale); -int matmul_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale); -int matmul_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale); -int matmul_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale); -int matmul_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale); -int matmul_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale); -int matmul_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale); -int matmul_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale); -int matmul_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale); -int matmul_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale); -int matmul_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale); -int matmul_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale); -int matmul_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale); -int matmul_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t 
*A, const float *B, uint8_t *C, double scale); -int matmul_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale); -int matmul_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C, double scale); -int matmul_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale); -int matmul_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C, double scale); -int matmul_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale); -int matmul_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale); -int matmul_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale); -int matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale); -int matmul_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale); -int matmul_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale); -int matmul_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale); -int matmul_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale); +extern int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double); +extern int (*matmul_f32_f32_f64)(size_t, size_t, size_t, const float *, const float *, double *, double); +extern int (*matmul_f32_f32_i8)(size_t, size_t, size_t, const float *, const float *, int8_t *, double); +extern int (*matmul_f32_f32_u8)(size_t, size_t, size_t, const float *, const float *, uint8_t *, double); +extern int (*matmul_f32_f64_f32)(size_t, size_t, size_t, const float *, const double *, float *, double); +extern int (*matmul_f32_f64_f64)(size_t, size_t, size_t, 
const float *, const double *, double *, double); +extern int (*matmul_f32_f64_i8)(size_t, size_t, size_t, const float *, const double *, int8_t *, double); +extern int (*matmul_f32_f64_u8)(size_t, size_t, size_t, const float *, const double *, uint8_t *, double); +extern int (*matmul_f32_i8_f32)(size_t, size_t, size_t, const float *, const int8_t *, float *, double); +extern int (*matmul_f32_i8_f64)(size_t, size_t, size_t, const float *, const int8_t *, double *, double); +extern int (*matmul_f32_i8_i8)(size_t, size_t, size_t, const float *, const int8_t *, int8_t *, double); +extern int (*matmul_f32_i8_u8)(size_t, size_t, size_t, const float *, const int8_t *, uint8_t *, double); +extern int (*matmul_f32_u8_f32)(size_t, size_t, size_t, const float *, const uint8_t *, float *, double); +extern int (*matmul_f32_u8_f64)(size_t, size_t, size_t, const float *, const uint8_t *, double *, double); +extern int (*matmul_f32_u8_i8)(size_t, size_t, size_t, const float *, const uint8_t *, int8_t *, double); +extern int (*matmul_f32_u8_u8)(size_t, size_t, size_t, const float *, const uint8_t *, uint8_t *, double); +extern int (*matmul_f64_f32_f32)(size_t, size_t, size_t, const double *, const float *, float *, double); +extern int (*matmul_f64_f32_f64)(size_t, size_t, size_t, const double *, const float *, double *, double); +extern int (*matmul_f64_f32_i8)(size_t, size_t, size_t, const double *, const float *, int8_t *, double); +extern int (*matmul_f64_f32_u8)(size_t, size_t, size_t, const double *, const float *, uint8_t *, double); +extern int (*matmul_f64_f64_f32)(size_t, size_t, size_t, const double *, const double *, float *, double); +extern int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, double); +extern int (*matmul_f64_f64_i8)(size_t, size_t, size_t, const double *, const double *, int8_t *, double); +extern int (*matmul_f64_f64_u8)(size_t, size_t, size_t, const double *, const double *, uint8_t *, double); +extern int 
(*matmul_f64_i8_f32)(size_t, size_t, size_t, const double *, const int8_t *, float *, double); +extern int (*matmul_f64_i8_f64)(size_t, size_t, size_t, const double *, const int8_t *, double *, double); +extern int (*matmul_f64_i8_i8)(size_t, size_t, size_t, const double *, const int8_t *, int8_t *, double); +extern int (*matmul_f64_i8_u8)(size_t, size_t, size_t, const double *, const int8_t *, uint8_t *, double); +extern int (*matmul_f64_u8_f32)(size_t, size_t, size_t, const double *, const uint8_t *, float *, double); +extern int (*matmul_f64_u8_f64)(size_t, size_t, size_t, const double *, const uint8_t *, double *, double); +extern int (*matmul_f64_u8_i8)(size_t, size_t, size_t, const double *, const uint8_t *, int8_t *, double); +extern int (*matmul_f64_u8_u8)(size_t, size_t, size_t, const double *, const uint8_t *, uint8_t *, double); +extern int (*matmul_i8_f32_f32)(size_t, size_t, size_t, const int8_t *, const float *, float *, double); +extern int (*matmul_i8_f32_f64)(size_t, size_t, size_t, const int8_t *, const float *, double *, double); +extern int (*matmul_i8_f32_i8)(size_t, size_t, size_t, const int8_t *, const float *, int8_t *, double); +extern int (*matmul_i8_f32_u8)(size_t, size_t, size_t, const int8_t *, const float *, uint8_t *, double); +extern int (*matmul_i8_f64_f32)(size_t, size_t, size_t, const int8_t *, const double *, float *, double); +extern int (*matmul_i8_f64_f64)(size_t, size_t, size_t, const int8_t *, const double *, double *, double); +extern int (*matmul_i8_f64_i8)(size_t, size_t, size_t, const int8_t *, const double *, int8_t *, double); +extern int (*matmul_i8_f64_u8)(size_t, size_t, size_t, const int8_t *, const double *, uint8_t *, double); +extern int (*matmul_i8_i8_f32)(size_t, size_t, size_t, const int8_t *, const int8_t *, float *, double); +extern int (*matmul_i8_i8_f64)(size_t, size_t, size_t, const int8_t *, const int8_t *, double *, double); +extern int (*matmul_i8_i8_i8)(size_t, size_t, size_t, const int8_t *, const 
int8_t *, int8_t *, double); +extern int (*matmul_i8_i8_u8)(size_t, size_t, size_t, const int8_t *, const int8_t *, uint8_t *, double); +extern int (*matmul_i8_u8_f32)(size_t, size_t, size_t, const int8_t *, const uint8_t *, float *, double); +extern int (*matmul_i8_u8_f64)(size_t, size_t, size_t, const int8_t *, const uint8_t *, double *, double); +extern int (*matmul_i8_u8_i8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, int8_t *, double); +extern int (*matmul_i8_u8_u8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, uint8_t *, double); +extern int (*matmul_u8_f32_f32)(size_t, size_t, size_t, const uint8_t *, const float *, float *, double); +extern int (*matmul_u8_f32_f64)(size_t, size_t, size_t, const uint8_t *, const float *, double *, double); +extern int (*matmul_u8_f32_i8)(size_t, size_t, size_t, const uint8_t *, const float *, int8_t *, double); +extern int (*matmul_u8_f32_u8)(size_t, size_t, size_t, const uint8_t *, const float *, uint8_t *, double); +extern int (*matmul_u8_f64_f32)(size_t, size_t, size_t, const uint8_t *, const double *, float *, double); +extern int (*matmul_u8_f64_f64)(size_t, size_t, size_t, const uint8_t *, const double *, double *, double); +extern int (*matmul_u8_f64_i8)(size_t, size_t, size_t, const uint8_t *, const double *, int8_t *, double); +extern int (*matmul_u8_f64_u8)(size_t, size_t, size_t, const uint8_t *, const double *, uint8_t *, double); +extern int (*matmul_u8_i8_f32)(size_t, size_t, size_t, const uint8_t *, const int8_t *, float *, double); +extern int (*matmul_u8_i8_f64)(size_t, size_t, size_t, const uint8_t *, const int8_t *, double *, double); +extern int (*matmul_u8_i8_i8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, int8_t *, double); +extern int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double); +extern int (*matmul_u8_u8_f32)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, float *, double); +extern int 
(*matmul_u8_u8_f64)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, double *, double); +extern int (*matmul_u8_u8_i8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, int8_t *, double); +extern int (*matmul_u8_u8_u8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, double); int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale); int matmul_scalar_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale); @@ -209,6 +209,44 @@ int matmul_avx512_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, con int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale); #endif +#ifdef __AVX2__ +int matmul_avx2_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale); +int matmul_avx2_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale); +int matmul_avx2_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale); +int matmul_avx2_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale); +int matmul_avx2_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale); +int matmul_avx2_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale); +int matmul_avx2_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale); +int matmul_avx2_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale); +int matmul_avx2_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale); +int matmul_avx2_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale); +int matmul_avx2_u8_i8_i8(size_t 
m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale); +int matmul_avx2_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale); +int matmul_avx2_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale); +int matmul_avx2_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale); +int matmul_avx2_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale); +int matmul_avx2_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale); +#endif + +#ifdef __AVX512F__ +int matmul_avx512_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale); +int matmul_avx512_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale); +int matmul_avx512_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale); +int matmul_avx512_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale); +int matmul_avx512_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale); +int matmul_avx512_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale); +int matmul_avx512_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale); +int matmul_avx512_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale); +int matmul_avx512_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale); +int matmul_avx512_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale); +int matmul_avx512_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t 
*C, double scale); +int matmul_avx512_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale); +int matmul_avx512_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale); +int matmul_avx512_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale); +int matmul_avx512_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale); +int matmul_avx512_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale); +#endif + #define matmul(m, n, p, A, B, C, scale) \ _Generic((A), \ float: _Generic((B), \