commit 5fca7db716571e8b1188f10d3cf38b0fa89f33b6
parent adb428832e725526e51335694a2917f4d0cc3091
Author: finwo <finwo@pm.me>
Date: Thu, 16 Apr 2026 23:30:08 +0200
Re-add vnni accel
Diffstat:
| M | src/matmul.c | | | 2270 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------- |
| M | src/matmul.h | | | 166 | ++++++++++++++++++++++++++++++++++++++++++++++++------------------------------- |
2 files changed, 1747 insertions(+), 689 deletions(-)
diff --git a/src/matmul.c b/src/matmul.c
@@ -55,7 +55,13 @@ static void init_feature(void) {
if (__builtin_cpu_supports("avx2")) g_feature |= MATMUL_FLAG_AVX2;
#endif
#ifdef __AVX512F__
- if (__builtin_cpu_supports("avx512f")) g_feature |= MATMUL_FLAG_AVX512;
+ if (__builtin_cpu_supports("avx512f")) {
+ g_feature |= MATMUL_FLAG_AVX512;
+ if (__builtin_cpu_supports("avx512vnni")) g_feature |= MATMUL_FLAG_AVX512_VNNI;
+ }
+#endif
+#ifdef __AVX2__
+ if (__builtin_cpu_supports("avx2") && __builtin_cpu_supports("avxvnni")) g_feature |= MATMUL_FLAG_AVXVNNI;
#endif
}
@@ -1433,952 +1439,1966 @@ int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, con
}
#endif
-int matmul_scalar_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) {
+#ifdef __AVX2__
+/* Horizontal sum: returns the sum of the eight 32-bit lanes of v. */
+static inline int32_t reduce_add_i32x8(__m256i v) {
+  /* Fold the upper 128-bit half onto the lower one: 8 lanes -> 4 partial sums. */
+  __m128i acc = _mm_add_epi32(_mm256_extracti128_si256(v, 0), _mm256_extracti128_si256(v, 1));
+  /* Two pairwise horizontal adds leave the full total in lane 0. */
+  acc = _mm_hadd_epi32(acc, acc);
+  acc = _mm_hadd_epi32(acc, acc);
+  return _mm_cvtsi128_si32(acc);
+}
+
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for int8 A, int8 B, float C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int
+ * before being stored. Always returns 0.
+ */
+int matmul_avx2_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) {
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (size_t i = 0; i < m; i++) {
for (size_t j = 0; j < p; j++) {
- int sum = 0;
- for (size_t k = 0; k < n; k++) {
- sum += (int)A[i * n + k] * (int)B[k * p + j];
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        int8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
}
- if (scale != 0 && scale != 1) sum = (int)(sum / scale);
- if (sum > 255)
- sum = 255;
- else if (sum < 0)
- sum = 0;
- C[i * p + j] = (uint8_t)sum;
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      C[i * p + j] = (float)s;
}
}
return 0;
}
-static int _matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const float *, float *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- impl = matmul_avx512_f32_f32_f32;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- impl = matmul_avx2_f32_f32_f32;
- else
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for int8 A, int8 B, double C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int
+ * before being stored. Always returns 0.
+ */
+int matmul_avx2_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
#endif
- impl = matmul_scalar_f32_f32_f32;
- initialized = 1;
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        int8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((double)s / scale);
+      C[i * p + j] = (double)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const float *, double *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- impl = matmul_avx512_f32_f32_f64;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- impl = matmul_avx2_f32_f32_f64;
- else
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for int8 A, int8 B, int8 C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int.
+ * The result is clamped to [-128, 127] before being stored. Always returns 0.
+ */
+int matmul_avx2_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
#endif
- impl = matmul_scalar_f32_f32_f64;
- initialized = 1;
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        int8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      if (s > 127)
+        s = 127;
+      else if (s < -128)
+        s = -128;
+      C[i * p + j] = (int8_t)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const float *, int8_t *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_f32_i8;
- initialized = 1;
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for int8 A, int8 B, uint8 C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int.
+ * The result is clamped to [0, 255] before being stored. Always returns 0.
+ */
+int matmul_avx2_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        int8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      if (s > 255)
+        s = 255;
+      else if (s < 0)
+        s = 0;
+      C[i * p + j] = (uint8_t)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const float *, uint8_t *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_f32_u8;
- initialized = 1;
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for int8 A, uint8 B, float C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int
+ * before being stored. Always returns 0.
+ */
+int matmul_avx2_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        uint8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      C[i * p + j] = (float)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const double *, float *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- impl = matmul_avx512_f32_f64_f32;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- impl = matmul_avx2_f32_f64_f32;
- else
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for int8 A, uint8 B, double C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int
+ * before being stored. Always returns 0.
+ */
+int matmul_avx2_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
#endif
- impl = matmul_scalar_f32_f64_f32;
- initialized = 1;
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        uint8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((double)s / scale);
+      C[i * p + j] = (double)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const double *, double *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- impl = matmul_avx512_f32_f64_f64;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- impl = matmul_avx2_f32_f64_f64;
- else
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for int8 A, uint8 B, int8 C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int.
+ * The result is clamped to [-128, 127] before being stored. Always returns 0.
+ */
+int matmul_avx2_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
#endif
- impl = matmul_scalar_f32_f64_f64;
- initialized = 1;
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        uint8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      if (s > 127)
+        s = 127;
+      else if (s < -128)
+        s = -128;
+      C[i * p + j] = (int8_t)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const double *, int8_t *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_f64_i8;
- initialized = 1;
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for int8 A, uint8 B, uint8 C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int.
+ * The result is clamped to [0, 255] before being stored. Always returns 0.
+ */
+int matmul_avx2_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        uint8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      if (s > 255)
+        s = 255;
+      else if (s < 0)
+        s = 0;
+      C[i * p + j] = (uint8_t)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const double *, uint8_t *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_f64_u8;
- initialized = 1;
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for uint8 A, int8 B, float C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int
+ * before being stored. Always returns 0.
+ */
+int matmul_avx2_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        int8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      C[i * p + j] = (float)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const int8_t *, float *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_i8_f32;
- initialized = 1;
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for uint8 A, int8 B, double C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int
+ * before being stored. Always returns 0.
+ */
+int matmul_avx2_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        int8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((double)s / scale);
+      C[i * p + j] = (double)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const int8_t *, double *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_i8_f64;
- initialized = 1;
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for uint8 A, int8 B, int8 C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int.
+ * The result is clamped to [-128, 127] before being stored. Always returns 0.
+ */
+int matmul_avx2_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        int8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      if (s > 127)
+        s = 127;
+      else if (s < -128)
+        s = -128;
+      C[i * p + j] = (int8_t)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const int8_t *, int8_t *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_i8_i8;
- initialized = 1;
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for uint8 A, int8 B, uint8 C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int.
+ * The result is clamped to [0, 255] before being stored. Always returns 0.
+ */
+int matmul_avx2_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        int8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      if (s > 255)
+        s = 255;
+      else if (s < 0)
+        s = 0;
+      C[i * p + j] = (uint8_t)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const int8_t *, uint8_t *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_i8_u8;
- initialized = 1;
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for uint8 A, uint8 B, float C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int
+ * before being stored. Always returns 0.
+ */
+int matmul_avx2_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        uint8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      C[i * p + j] = (float)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const uint8_t *, float *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_u8_f32;
- initialized = 1;
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for uint8 A, uint8 B, double C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int
+ * before being stored. Always returns 0.
+ */
+int matmul_avx2_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        uint8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((double)s / scale);
+      C[i * p + j] = (double)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const uint8_t *, double *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_u8_f64;
- initialized = 1;
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for uint8 A, uint8 B, int8 C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int.
+ * The result is clamped to [-128, 127] before being stored. Always returns 0.
+ */
+int matmul_avx2_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        uint8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      if (s > 127)
+        s = 127;
+      else if (s < -128)
+        s = -128;
+      C[i * p + j] = (int8_t)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const uint8_t *, int8_t *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_u8_i8;
- initialized = 1;
+/*
+ * AVX2 kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for uint8 A, uint8 B, uint8 C.
+ * Products are accumulated in 32-bit integers, 32 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int.
+ * The result is clamped to [0, 255] before being stored. Always returns 0.
+ */
+int matmul_avx2_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m256i sum_vec = _mm256_setzero_si256();
+      size_t k = 0;
+      for (; k + 31 < n; k += 32) {
+        /* Column j of B is strided by p: gather rows k..k+31 into a contiguous buffer first. */
+        uint8_t b_col[32];
+        for (size_t t = 0; t < 32; t++) b_col[t] = B[(k + t) * p + j];
+        __m256i a_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m256i a_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m256i b_lo = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m256i b_hi = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_lo, b_lo));
+        sum_vec = _mm256_add_epi32(sum_vec, _mm256_madd_epi16(a_hi, b_hi));
+      }
+      int s = reduce_add_i32x8(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      if (s > 255)
+        s = 255;
+      else if (s < 0)
+        s = 0;
+      C[i * p + j] = (uint8_t)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
+#endif
-static int _matmul_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const float *, const uint8_t *, uint8_t *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_f32_u8_u8;
- initialized = 1;
- }
- return impl(m, n, p, A, B, C, scale);
+#ifdef __AVX512F__
+/* Horizontal sum: returns the sum of the sixteen 32-bit lanes of v. */
+static inline int32_t reduce_add_i32x16(__m512i v) {
+  /* 16 lanes -> 8: add the upper 256-bit half onto the lower one. */
+  __m256i low = _mm512_extracti64x4_epi64(v, 0);
+  __m256i high = _mm512_extracti64x4_epi64(v, 1);
+  __m256i sum256 = _mm256_add_epi32(low, high);
+  /* 8 lanes -> 4: _mm256_hadd_epi32 works per 128-bit lane and never mixes the two
+   * halves, so the halves must be added explicitly rather than dropping the upper one. */
+  __m128i s128 = _mm_add_epi32(_mm256_extracti128_si256(sum256, 0), _mm256_extracti128_si256(sum256, 1));
+  /* 4 lanes -> 1: two pairwise horizontal adds leave the total in lane 0. */
+  s128 = _mm_hadd_epi32(s128, s128);
+  s128 = _mm_hadd_epi32(s128, s128);
+  return _mm_cvtsi128_si32(s128);
}
-static int _matmul_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const float *, float *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- impl = matmul_avx512_f64_f32_f32;
- else
-#endif
-#ifdef __AVX2__
- if (feat & MATMUL_FLAG_AVX2)
- impl = matmul_avx2_f64_f32_f32;
- else
+/*
+ * AVX-512F kernel: C[i*p + j] = sum_k A[i*n + k] * B[k*p + j] for int8 A, int8 B, float C.
+ * Products are accumulated in 32-bit integers, 64 values of k per vector step.
+ * If scale is neither 0 nor 1 the sum is divided by scale and truncated to int
+ * before being stored. Always returns 0.
+ */
+int matmul_avx512_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
#endif
- impl = matmul_scalar_f64_f32_f32;
- initialized = 1;
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: gather rows k..k+63 into a contiguous buffer first. */
+        int8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        for (size_t off = 0; off < 64; off += 16) {
+          __m512i a = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + off]));
+          __m512i b = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[off]));
+          /* Operands are widened to 32 bits, so use a full 32-bit multiply. The VNNI
+           * word-pair dot product (_mm512_dpwssd_epi32) would also multiply the
+           * sign-extension words and add +1 whenever both inputs are negative, and it
+           * is not available under the AVX512F-only guard of this section. */
+          sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a, b));
+        }
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j]; /* scalar tail */
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      C[i * p + j] = (float)s;
+    }
}
- return impl(m, n, p, A, B, C, scale);
+  return 0;
}
-static int _matmul_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const float *, double *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- matmul_feature_t feat = matmul_get_feature();
-#ifdef __AVX512F__
- if (feat & MATMUL_FLAG_AVX512)
- impl = matmul_avx512_f64_f32_f64;
- else
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major int8 A (m x n)
+ * and int8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; the quotient is truncated to int) and stored as double. Returns 0.
+ */
+int matmul_avx512_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        int8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* Multiply the 32-bit lanes directly. A 16-bit pair dot product would also
+         * multiply the high halves, which are 0xFFFF for sign-extended negatives,
+         * adding a spurious (-1) * (-1) = +1 per negative x negative pair. */
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a0, b0));
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a1, b1));
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a2, b2));
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a3, b3));
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((double)s / scale);
+      C[i * p + j] = (double)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major int8 A (m x n)
+ * and int8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; truncated to int), clamped to [-128, 127] and stored as int8.
+ * Returns 0.
+ */
+int matmul_avx512_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        int8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* Multiply the 32-bit lanes directly. A 16-bit pair dot product would also
+         * multiply the high halves, which are 0xFFFF for sign-extended negatives,
+         * adding a spurious (-1) * (-1) = +1 per negative x negative pair. */
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a0, b0));
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a1, b1));
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a2, b2));
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a3, b3));
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      /* Saturate to the int8 output range. */
+      if (s > 127)
+        s = 127;
+      else if (s < -128)
+        s = -128;
+      C[i * p + j] = (int8_t)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major int8 A (m x n)
+ * and int8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; truncated to int), clamped to [0, 255] and stored as uint8.
+ * Returns 0.
+ */
+int matmul_avx512_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        int8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* Multiply the 32-bit lanes directly. A 16-bit pair dot product would also
+         * multiply the high halves, which are 0xFFFF for sign-extended negatives,
+         * adding a spurious (-1) * (-1) = +1 per negative x negative pair. */
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a0, b0));
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a1, b1));
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a2, b2));
+        sum_vec = _mm512_add_epi32(sum_vec, _mm512_mullo_epi32(a3, b3));
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      /* Saturate to the uint8 output range. */
+      if (s > 255)
+        s = 255;
+      else if (s < 0)
+        s = 0;
+      C[i * p + j] = (uint8_t)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major int8 A (m x n)
+ * and uint8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; the quotient is truncated to int) and stored as float. Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        uint8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* B is zero-extended, so the high 16-bit half of each B lane is 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      C[i * p + j] = (float)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major int8 A (m x n)
+ * and uint8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; the quotient is truncated to int) and stored as double. Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        uint8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* B is zero-extended, so the high 16-bit half of each B lane is 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((double)s / scale);
+      C[i * p + j] = (double)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major int8 A (m x n)
+ * and uint8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; truncated to int), clamped to [-128, 127] and stored as int8.
+ * Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        uint8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* B is zero-extended, so the high 16-bit half of each B lane is 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      /* Saturate to the int8 output range. */
+      if (s > 127)
+        s = 127;
+      else if (s < -128)
+        s = -128;
+      C[i * p + j] = (int8_t)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major int8 A (m x n)
+ * and uint8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; truncated to int), clamped to [0, 255] and stored as uint8.
+ * Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        uint8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* B is zero-extended, so the high 16-bit half of each B lane is 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      /* Saturate to the uint8 output range. */
+      if (s > 255)
+        s = 255;
+      else if (s < 0)
+        s = 0;
+      C[i * p + j] = (uint8_t)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major uint8 A (m x n)
+ * and int8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; the quotient is truncated to int) and stored as float. Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        int8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* A is zero-extended, so the high 16-bit half of each A lane is 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      C[i * p + j] = (float)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major uint8 A (m x n)
+ * and int8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; the quotient is truncated to int) and stored as double. Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        int8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* A is zero-extended, so the high 16-bit half of each A lane is 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((double)s / scale);
+      C[i * p + j] = (double)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major uint8 A (m x n)
+ * and int8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; truncated to int), clamped to [-128, 127] and stored as int8.
+ * Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        int8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* A is zero-extended, so the high 16-bit half of each A lane is 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      /* Saturate to the int8 output range. */
+      if (s > 127)
+        s = 127;
+      else if (s < -128)
+        s = -128;
+      C[i * p + j] = (int8_t)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major uint8 A (m x n)
+ * and int8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; truncated to int), clamped to [0, 255] and stored as uint8.
+ * Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        int8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* A is zero-extended, so the high 16-bit half of each A lane is 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      /* Saturate to the uint8 output range. */
+      if (s > 255)
+        s = 255;
+      else if (s < 0)
+        s = 0;
+      C[i * p + j] = (uint8_t)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major uint8 A (m x n)
+ * and uint8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; the quotient is truncated to int) and stored as float. Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        uint8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* Both operands are zero-extended, so the high 16-bit halves are 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      C[i * p + j] = (float)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major uint8 A (m x n)
+ * and uint8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; the quotient is truncated to int) and stored as double. Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        uint8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* Both operands are zero-extended, so the high 16-bit halves are 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((double)s / scale);
+      C[i * p + j] = (double)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major uint8 A (m x n)
+ * and uint8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; truncated to int), clamped to [-128, 127] and stored as int8.
+ * Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        uint8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* Both operands are zero-extended, so the high 16-bit halves are 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      if (scale != 0 && scale != 1) s = (int)((float)s / (float)scale);
+      /* Saturate to the int8 output range. */
+      if (s > 127)
+        s = 127;
+      else if (s < -128)
+        s = -128;
+      C[i * p + j] = (int8_t)s;
+    }
+  }
+  return 0;
+}
+
+/*
+ * AVX-512 kernel: C[i][j] = sum_k A[i][k] * B[k][j] for row-major uint8 A (m x n)
+ * and uint8 B (n x p). The int32 sum is divided by scale (skipped when scale is
+ * 0 or 1; truncated to int), clamped to [0, 255] and stored as uint8.
+ * Returns 0.
+ * NOTE(review): _mm512_dpwssd_epi32 is an AVX512-VNNI instruction while this
+ * block is guarded by __AVX512F__ only -- confirm the dispatcher requires VNNI.
+ */
+int matmul_avx512_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t i = 0; i < m; i++) {
+    for (size_t j = 0; j < p; j++) {
+      __m512i sum_vec = _mm512_setzero_si512();
+      size_t k = 0;
+      for (; k + 63 < n; k += 64) {
+        /* Column j of B is strided by p: pack B[k..k+63][j] so the 16-byte loads
+         * below see consecutive k (loading at &B[k * p + j] would read row k). */
+        uint8_t b_col[64];
+        for (size_t t = 0; t < 64; t++) b_col[t] = B[(k + t) * p + j];
+        __m512i a0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k]));
+        __m512i a1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 16]));
+        __m512i a2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 32]));
+        __m512i a3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&A[i * n + k + 48]));
+        __m512i b0 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[0]));
+        __m512i b1 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[16]));
+        __m512i b2 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[32]));
+        __m512i b3 = _mm512_cvtepu8_epi32(_mm_loadu_si128((const __m128i *)&b_col[48]));
+        /* Both operands are zero-extended, so the high 16-bit halves are 0 and the
+         * pairwise dot product reduces to the plain per-lane product. */
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a0, b0);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a1, b1);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a2, b2);
+        sum_vec = _mm512_dpwssd_epi32(sum_vec, a3, b3);
+      }
+      int s = reduce_add_i32x16(sum_vec);
+      /* Scalar tail for the last n % 64 terms. */
+      for (; k < n; k++) s += (int)A[i * n + k] * (int)B[k * p + j];
+      /* Divide in double, as matmul_scalar_u8_u8_u8 does, so both code paths
+       * truncate to the same integer even when |s| exceeds float precision. */
+      if (scale != 0 && scale != 1) s = (int)((double)s / scale);
+      /* Saturate to the uint8 output range. */
+      if (s > 255)
+        s = 255;
+      else if (s < 0)
+        s = 0;
+      C[i * p + j] = (uint8_t)s;
+    }
+  }
+  return 0;
+}
+#endif
+
+/*
+ * Portable reference kernel: C = (A x B) / scale for row-major uint8 A (m x n)
+ * and uint8 B (n x p). Each dot product is accumulated in int, divided by scale
+ * unless scale is 0 or 1 (quotient truncated to int), clamped to [0, 255] and
+ * stored as uint8 in C (m x p). Returns 0.
+ */
+int matmul_scalar_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) {
+#ifdef _OPENMP
+#pragma omp parallel for schedule(static)
+#endif
+  for (size_t row = 0; row < m; row++) {
+    const size_t a_base = row * n;
+    const size_t c_base = row * p;
+    for (size_t col = 0; col < p; col++) {
+      int acc = 0;
+      for (size_t k = 0; k < n; k++) acc += (int)A[a_base + k] * (int)B[k * p + col];
+      if (scale != 0 && scale != 1) acc = (int)(acc / scale);
+      /* Saturate to the uint8 output range. */
+      const int clamped = acc > 255 ? 255 : (acc < 0 ? 0 : acc);
+      C[c_base + col] = (uint8_t)clamped;
+    }
+  }
+  return 0;
+}
+
+/*
+ * First-call dispatcher for matmul_f32_f32_f32: picks the AVX-512, AVX2 or
+ * scalar kernel (in that order of preference) from the compiled-in options and
+ * the flags returned by matmul_get_feature(), stores it in the
+ * matmul_f32_f32_f32 function pointer, then forwards the call through it.
+ */
+static int _matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
+  static int bound = 0;
+  if (!bound) {
+    int (*pick)(size_t, size_t, size_t, const float *, const float *, float *, double) = matmul_scalar_f32_f32_f32;
+    matmul_feature_t feat = matmul_get_feature();
+#ifdef __AVX2__
+    if (feat & MATMUL_FLAG_AVX2) pick = matmul_avx2_f32_f32_f32;
+#endif
+#ifdef __AVX512F__
+    /* Tested last so AVX-512 overrides AVX2 when both are available. */
+    if (feat & MATMUL_FLAG_AVX512) pick = matmul_avx512_f32_f32_f32;
+#endif
+    matmul_f32_f32_f32 = pick;
+    bound = 1;
+  }
+  return matmul_f32_f32_f32(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_f32_f64: picks the AVX-512, AVX2 or
+ * scalar kernel (in that order of preference) from the compiled-in options and
+ * the flags returned by matmul_get_feature(), stores it in the
+ * matmul_f32_f32_f64 function pointer, then forwards the call through it.
+ */
+static int _matmul_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) {
+  static int bound = 0;
+  if (!bound) {
+    int (*pick)(size_t, size_t, size_t, const float *, const float *, double *, double) = matmul_scalar_f32_f32_f64;
+    matmul_feature_t feat = matmul_get_feature();
+#ifdef __AVX2__
+    if (feat & MATMUL_FLAG_AVX2) pick = matmul_avx2_f32_f32_f64;
+#endif
+#ifdef __AVX512F__
+    /* Tested last so AVX-512 overrides AVX2 when both are available. */
+    if (feat & MATMUL_FLAG_AVX512) pick = matmul_avx512_f32_f32_f64;
+#endif
+    matmul_f32_f32_f64 = pick;
+    bound = 1;
+  }
+  return matmul_f32_f32_f64(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_f32_i8: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_f32_i8(m, n, p, A, B, C, scale);
+
+  matmul_f32_f32_i8 = matmul_scalar_f32_f32_i8;
+  bound = 1;
+  return matmul_f32_f32_i8(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_f32_u8: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_f32_u8(m, n, p, A, B, C, scale);
+
+  matmul_f32_f32_u8 = matmul_scalar_f32_f32_u8;
+  bound = 1;
+  return matmul_f32_f32_u8(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_f64_f32: picks the AVX-512, AVX2 or
+ * scalar kernel (in that order of preference) from the compiled-in options and
+ * the flags returned by matmul_get_feature(), stores it in the
+ * matmul_f32_f64_f32 function pointer, then forwards the call through it.
+ */
+static int _matmul_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) {
+  static int bound = 0;
+  if (!bound) {
+    int (*pick)(size_t, size_t, size_t, const float *, const double *, float *, double) = matmul_scalar_f32_f64_f32;
+    matmul_feature_t feat = matmul_get_feature();
+#ifdef __AVX2__
+    if (feat & MATMUL_FLAG_AVX2) pick = matmul_avx2_f32_f64_f32;
+#endif
+#ifdef __AVX512F__
+    /* Tested last so AVX-512 overrides AVX2 when both are available. */
+    if (feat & MATMUL_FLAG_AVX512) pick = matmul_avx512_f32_f64_f32;
+#endif
+    matmul_f32_f64_f32 = pick;
+    bound = 1;
+  }
+  return matmul_f32_f64_f32(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_f64_f64: picks the AVX-512, AVX2 or
+ * scalar kernel (in that order of preference) from the compiled-in options and
+ * the flags returned by matmul_get_feature(), stores it in the
+ * matmul_f32_f64_f64 function pointer, then forwards the call through it.
+ */
+static int _matmul_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) {
+  static int bound = 0;
+  if (!bound) {
+    int (*pick)(size_t, size_t, size_t, const float *, const double *, double *, double) = matmul_scalar_f32_f64_f64;
+    matmul_feature_t feat = matmul_get_feature();
+#ifdef __AVX2__
+    if (feat & MATMUL_FLAG_AVX2) pick = matmul_avx2_f32_f64_f64;
+#endif
+#ifdef __AVX512F__
+    /* Tested last so AVX-512 overrides AVX2 when both are available. */
+    if (feat & MATMUL_FLAG_AVX512) pick = matmul_avx512_f32_f64_f64;
+#endif
+    matmul_f32_f64_f64 = pick;
+    bound = 1;
+  }
+  return matmul_f32_f64_f64(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_f64_i8: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_f64_i8(m, n, p, A, B, C, scale);
+
+  matmul_f32_f64_i8 = matmul_scalar_f32_f64_i8;
+  bound = 1;
+  return matmul_f32_f64_i8(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_f64_u8: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_f64_u8(m, n, p, A, B, C, scale);
+
+  matmul_f32_f64_u8 = matmul_scalar_f32_f64_u8;
+  bound = 1;
+  return matmul_f32_f64_u8(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_i8_f32: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_i8_f32(m, n, p, A, B, C, scale);
+
+  matmul_f32_i8_f32 = matmul_scalar_f32_i8_f32;
+  bound = 1;
+  return matmul_f32_i8_f32(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_i8_f64: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_i8_f64(m, n, p, A, B, C, scale);
+
+  matmul_f32_i8_f64 = matmul_scalar_f32_i8_f64;
+  bound = 1;
+  return matmul_f32_i8_f64(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_i8_i8: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_i8_i8(m, n, p, A, B, C, scale);
+
+  matmul_f32_i8_i8 = matmul_scalar_f32_i8_i8;
+  bound = 1;
+  return matmul_f32_i8_i8(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_i8_u8: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_i8_u8(m, n, p, A, B, C, scale);
+
+  matmul_f32_i8_u8 = matmul_scalar_f32_i8_u8;
+  bound = 1;
+  return matmul_f32_i8_u8(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_u8_f32: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_u8_f32(m, n, p, A, B, C, scale);
+
+  matmul_f32_u8_f32 = matmul_scalar_f32_u8_f32;
+  bound = 1;
+  return matmul_f32_u8_f32(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_u8_f64: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_u8_f64(m, n, p, A, B, C, scale);
+
+  matmul_f32_u8_f64 = matmul_scalar_f32_u8_f64;
+  bound = 1;
+  return matmul_f32_u8_f64(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_u8_i8: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_u8_i8(m, n, p, A, B, C, scale);
+
+  matmul_f32_u8_i8 = matmul_scalar_f32_u8_i8;
+  bound = 1;
+  return matmul_f32_u8_i8(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f32_u8_u8: binds the function pointer to
+ * the scalar kernel (the only candidate chosen here) and forwards the call.
+ */
+static int _matmul_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale) {
+  static int bound = 0;
+  if (bound) return matmul_f32_u8_u8(m, n, p, A, B, C, scale);
+
+  matmul_f32_u8_u8 = matmul_scalar_f32_u8_u8;
+  bound = 1;
+  return matmul_f32_u8_u8(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f64_f32_f32: picks the AVX-512, AVX2 or
+ * scalar kernel (in that order of preference) from the compiled-in options and
+ * the flags returned by matmul_get_feature(), stores it in the
+ * matmul_f64_f32_f32 function pointer, then forwards the call through it.
+ */
+static int _matmul_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) {
+  static int bound = 0;
+  if (!bound) {
+    int (*pick)(size_t, size_t, size_t, const double *, const float *, float *, double) = matmul_scalar_f64_f32_f32;
+    matmul_feature_t feat = matmul_get_feature();
+#ifdef __AVX2__
+    if (feat & MATMUL_FLAG_AVX2) pick = matmul_avx2_f64_f32_f32;
+#endif
+#ifdef __AVX512F__
+    /* Tested last so AVX-512 overrides AVX2 when both are available. */
+    if (feat & MATMUL_FLAG_AVX512) pick = matmul_avx512_f64_f32_f32;
+#endif
+    matmul_f64_f32_f32 = pick;
+    bound = 1;
+  }
+  return matmul_f64_f32_f32(m, n, p, A, B, C, scale);
+}
+
+/*
+ * First-call dispatcher for matmul_f64_f32_f64: on the first invocation it
+ * stores the AVX-512, AVX2 or scalar kernel (first match wins, subject to the
+ * compiled-in #ifdef branches and matmul_get_feature()) in the
+ * matmul_f64_f32_f64 function pointer, then forwards the call through it.
+ * This hunk replaces the former function-local `impl` pointer with that
+ * file-level pointer.
+ */
+static int _matmul_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) {
+ static int initialized = 0;
+ if (!initialized) {
+ matmul_feature_t feat = matmul_get_feature();
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512)
+ matmul_f64_f32_f64 = matmul_avx512_f64_f32_f64;
+ else
#endif
#ifdef __AVX2__
if (feat & MATMUL_FLAG_AVX2)
- impl = matmul_avx2_f64_f32_f64;
+ matmul_f64_f32_f64 = matmul_avx2_f64_f32_f64;
else
#endif
- impl = matmul_scalar_f64_f32_f64;
+ matmul_f64_f32_f64 = matmul_scalar_f64_f32_f64;
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_f32_f64(m, n, p, A, B, C, scale);
}
static int _matmul_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const float *, int8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_f32_i8;
- initialized = 1;
+ matmul_f64_f32_i8 = matmul_scalar_f64_f32_i8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_f32_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const float *, uint8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_f32_u8;
- initialized = 1;
+ matmul_f64_f32_u8 = matmul_scalar_f64_f32_u8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_f32_u8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const double *, float *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
matmul_feature_t feat = matmul_get_feature();
#ifdef __AVX512F__
if (feat & MATMUL_FLAG_AVX512)
- impl = matmul_avx512_f64_f64_f32;
+ matmul_f64_f64_f32 = matmul_avx512_f64_f64_f32; /* taken when the AVX512 flag is set; branch is compiled only with __AVX512F__ */
else
#endif
#ifdef __AVX2__
if (feat & MATMUL_FLAG_AVX2)
- impl = matmul_avx2_f64_f64_f32;
+ matmul_f64_f64_f32 = matmul_avx2_f64_f64_f32; /* taken when the AVX2 flag is set; branch is compiled only with __AVX2__ */
else
#endif
- impl = matmul_scalar_f64_f64_f32;
+ matmul_f64_f64_f32 = matmul_scalar_f64_f64_f32; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_f64_f32(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C,
double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const double *, double *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
matmul_feature_t feat = matmul_get_feature();
#ifdef __AVX512F__
if (feat & MATMUL_FLAG_AVX512)
- impl = matmul_avx512_f64_f64_f64;
+ matmul_f64_f64_f64 = matmul_avx512_f64_f64_f64; /* taken when the AVX512 flag is set; branch is compiled only with __AVX512F__ */
else
#endif
#ifdef __AVX2__
if (feat & MATMUL_FLAG_AVX2)
- impl = matmul_avx2_f64_f64_f64;
+ matmul_f64_f64_f64 = matmul_avx2_f64_f64_f64; /* taken when the AVX2 flag is set; branch is compiled only with __AVX2__ */
else
#endif
- impl = matmul_scalar_f64_f64_f64;
+ matmul_f64_f64_f64 = matmul_scalar_f64_f64_f64; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_f64_f64(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const double *, int8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_f64_i8;
- initialized = 1;
+ matmul_f64_f64_i8 = matmul_scalar_f64_f64_i8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_f64_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C,
double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const double *, uint8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_f64_u8;
- initialized = 1;
+ matmul_f64_f64_u8 = matmul_scalar_f64_f64_u8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_f64_u8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const int8_t *, float *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_i8_f32;
- initialized = 1;
+ matmul_f64_i8_f32 = matmul_scalar_f64_i8_f32; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_i8_f32(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const int8_t *, double *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_i8_f64;
- initialized = 1;
+ matmul_f64_i8_f64 = matmul_scalar_f64_i8_f64; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_i8_f64(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const int8_t *, int8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_i8_i8;
- initialized = 1;
+ matmul_f64_i8_i8 = matmul_scalar_f64_i8_i8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_i8_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const int8_t *, uint8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_i8_u8;
- initialized = 1;
+ matmul_f64_i8_u8 = matmul_scalar_f64_i8_u8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_i8_u8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const uint8_t *, float *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_u8_f32;
- initialized = 1;
+ matmul_f64_u8_f32 = matmul_scalar_f64_u8_f32; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_u8_f32(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C,
double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const uint8_t *, double *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_u8_f64;
- initialized = 1;
+ matmul_f64_u8_f64 = matmul_scalar_f64_u8_f64; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_u8_f64(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const uint8_t *, int8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_u8_i8;
- initialized = 1;
+ matmul_f64_u8_i8 = matmul_scalar_f64_u8_i8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_u8_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C,
double scale) {
- static int (*impl)(size_t, size_t, size_t, const double *, const uint8_t *, uint8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_f64_u8_u8;
- initialized = 1;
+ matmul_f64_u8_u8 = matmul_scalar_f64_u8_u8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_f64_u8_u8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const float *, float *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_f32_f32;
- initialized = 1;
+ matmul_i8_f32_f32 = matmul_scalar_i8_f32_f32; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_f32_f32(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const float *, double *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_f32_f64;
- initialized = 1;
+ matmul_i8_f32_f64 = matmul_scalar_i8_f32_f64; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_f32_f64(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const float *, int8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_f32_i8;
- initialized = 1;
+ matmul_i8_f32_i8 = matmul_scalar_i8_f32_i8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_f32_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const float *, uint8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_f32_u8;
- initialized = 1;
+ matmul_i8_f32_u8 = matmul_scalar_i8_f32_u8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_f32_u8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const double *, float *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_f64_f32;
- initialized = 1;
+ matmul_i8_f64_f32 = matmul_scalar_i8_f64_f32; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_f64_f32(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const double *, double *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_f64_f64;
- initialized = 1;
+ matmul_i8_f64_f64 = matmul_scalar_i8_f64_f64; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_f64_f64(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const double *, int8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_f64_i8;
- initialized = 1;
+ matmul_i8_f64_i8 = matmul_scalar_i8_f64_i8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_f64_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const double *, uint8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_f64_u8;
- initialized = 1;
+ matmul_i8_f64_u8 = matmul_scalar_i8_f64_u8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_f64_u8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const int8_t *, float *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_i8_f32;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_i8_i8_f32 = matmul_avx512_i8_i8_f32; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_i8_i8_f32 = matmul_avx2_i8_i8_f32; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_i8_i8_f32 = matmul_scalar_i8_i8_f32; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_i8_f32(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const int8_t *, double *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_i8_f64;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_i8_i8_f64 = matmul_avx512_i8_i8_f64; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_i8_i8_f64 = matmul_avx2_i8_i8_f64; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_i8_i8_f64 = matmul_scalar_i8_i8_f64; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_i8_f64(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const int8_t *, int8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_i8_i8;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_i8_i8_i8 = matmul_avx512_i8_i8_i8; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_i8_i8_i8 = matmul_avx2_i8_i8_i8; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_i8_i8_i8 = matmul_scalar_i8_i8_i8; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_i8_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const int8_t *, uint8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_i8_u8;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_i8_i8_u8 = matmul_avx512_i8_i8_u8; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_i8_i8_u8 = matmul_avx2_i8_i8_u8; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_i8_i8_u8 = matmul_scalar_i8_i8_u8; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_i8_u8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const uint8_t *, float *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_u8_f32;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_i8_u8_f32 = matmul_avx512_i8_u8_f32; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_i8_u8_f32 = matmul_avx2_i8_u8_f32; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_i8_u8_f32 = matmul_scalar_i8_u8_f32; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_u8_f32(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const uint8_t *, double *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_u8_f64;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_i8_u8_f64 = matmul_avx512_i8_u8_f64; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_i8_u8_f64 = matmul_avx2_i8_u8_f64; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_i8_u8_f64 = matmul_scalar_i8_u8_f64; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_u8_f64(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const uint8_t *, int8_t *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_i8_u8_i8;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_i8_u8_i8 = matmul_avx512_i8_u8_i8; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_i8_u8_i8 = matmul_avx2_i8_u8_i8; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_i8_u8_i8 = matmul_scalar_i8_u8_i8; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_u8_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
static int _matmul_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const int8_t *, const uint8_t *, uint8_t *, double) = NULL;
- static int initialized = 0;
- if (!initialized) {
- impl = matmul_scalar_i8_u8_u8;
- initialized = 1;
- }
- return impl(m, n, p, A, B, C, scale);
-}
-
-static int _matmul_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const float *, float *, double) = NULL;
- static int initialized = 0;
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_f32_f32;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_i8_u8_u8 = matmul_avx512_i8_u8_u8; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_i8_u8_u8 = matmul_avx2_i8_u8_u8; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_i8_u8_u8 = matmul_scalar_i8_u8_u8; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_i8_u8_u8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const float *, double *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_f32_f64;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_u8_i8_f32 = matmul_avx512_u8_i8_f32; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_u8_i8_f32 = matmul_avx2_u8_i8_f32; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_u8_i8_f32 = matmul_scalar_u8_i8_f32; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_i8_f32(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const float *, int8_t *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_f32_i8;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_u8_i8_f64 = matmul_avx512_u8_i8_f64; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_u8_i8_f64 = matmul_avx2_u8_i8_f64; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_u8_i8_f64 = matmul_scalar_u8_i8_f64; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_i8_f64(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const float *, uint8_t *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_f32_u8;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_u8_i8_i8 = matmul_avx512_u8_i8_i8; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_u8_i8_i8 = matmul_avx2_u8_i8_i8; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_u8_i8_i8 = matmul_scalar_u8_i8_i8; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_i8_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const double *, float *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_f64_f32;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_u8_i8_u8 = matmul_avx512_u8_i8_u8; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_u8_i8_u8 = matmul_avx2_u8_i8_u8; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_u8_i8_u8 = matmul_scalar_u8_i8_u8; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_i8_u8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C,
- double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const double *, double *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_f64_f64;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_u8_u8_f32 = matmul_avx512_u8_u8_f32; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_u8_u8_f32 = matmul_avx2_u8_u8_f32; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_u8_u8_f32 = matmul_scalar_u8_u8_f32; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_u8_f32(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const double *, int8_t *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C,
+ double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_f64_i8;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_u8_u8_f64 = matmul_avx512_u8_u8_f64; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_u8_u8_f64 = matmul_avx2_u8_u8_f64; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_u8_u8_f64 = matmul_scalar_u8_u8_f64; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_u8_f64(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C,
- double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const double *, uint8_t *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_f64_u8;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_u8_u8_i8 = matmul_avx512_u8_u8_i8; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_u8_u8_i8 = matmul_avx2_u8_u8_i8; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_u8_u8_i8 = matmul_scalar_u8_u8_i8; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_u8_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const int8_t *, float *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C,
+ double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_i8_f32;
+ matmul_feature_t feat = matmul_get_feature(); /* feature bitmask tested by the branches below */
+#ifdef __AVX512F__
+ if (feat & MATMUL_FLAG_AVX512_VNNI)
+ matmul_u8_u8_u8 = matmul_avx512_u8_u8_u8; /* taken when the AVX512_VNNI flag is set; branch is compiled only with __AVX512F__ */
+ else
+#endif
+#ifdef __AVX2__
+ if (feat & MATMUL_FLAG_AVXVNNI)
+ matmul_u8_u8_u8 = matmul_avx2_u8_u8_u8; /* taken when the AVXVNNI flag is set; NOTE(review): avx2-named kernel gated on AVXVNNI rather than AVX2 - confirm intended */
+ else
+#endif
+ matmul_u8_u8_u8 = matmul_scalar_u8_u8_u8; /* fallback when no vector branch above was taken */
initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_u8_u8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const int8_t *, double *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_i8_f64;
- initialized = 1;
+ matmul_u8_f32_f32 = matmul_scalar_u8_f32_f32; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_f32_f32(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const int8_t *, int8_t *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_i8_i8;
- initialized = 1;
+ matmul_u8_f32_f64 = matmul_scalar_u8_f32_f64; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_f32_f64(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_i8_u8;
- initialized = 1;
+ matmul_u8_f32_i8 = matmul_scalar_u8_f32_i8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_f32_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, float *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_u8_f32;
- initialized = 1;
+ matmul_u8_f32_u8 = matmul_scalar_u8_f32_u8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_f32_u8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C,
- double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, double *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_u8_f64;
- initialized = 1;
+ matmul_u8_f64_f32 = matmul_scalar_u8_f64_f32; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_f64_f32(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, int8_t *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C,
+ double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_u8_i8;
- initialized = 1;
+ matmul_u8_f64_f64 = matmul_scalar_u8_f64_f64; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_f64_f64(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-static int _matmul_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C,
- double scale) {
- static int (*impl)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, double) = NULL;
- static int initialized = 0;
+static int _matmul_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale) {
+ static int initialized = 0; /* run-once guard; NOTE(review): unsynchronized static - confirm first use cannot race */
if (!initialized) {
- impl = matmul_scalar_u8_u8_u8;
- initialized = 1;
+ matmul_u8_f64_i8 = matmul_scalar_u8_f64_i8; /* this resolver only ever binds the scalar kernel */
+ initialized = 1;
}
- return impl(m, n, p, A, B, C, scale);
+ return matmul_u8_f64_i8(m, n, p, A, B, C, scale); /* forward to whatever the pointer currently targets */
}
-int matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
- return _matmul_f32_f32_f32(m, n, p, A, B, C, scale);
-}
-int matmul_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale) {
- return _matmul_f32_f32_f64(m, n, p, A, B, C, scale);
-}
-int matmul_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale) {
- return _matmul_f32_f32_i8(m, n, p, A, B, C, scale);
-}
-int matmul_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale) {
- return _matmul_f32_f32_u8(m, n, p, A, B, C, scale);
-}
-int matmul_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale) {
- return _matmul_f32_f64_f32(m, n, p, A, B, C, scale);
-}
-int matmul_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale) {
- return _matmul_f32_f64_f64(m, n, p, A, B, C, scale);
-}
-int matmul_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale) {
- return _matmul_f32_f64_i8(m, n, p, A, B, C, scale);
-}
-int matmul_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale) {
- return _matmul_f32_f64_u8(m, n, p, A, B, C, scale);
-}
-int matmul_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale) {
- return _matmul_f32_i8_f32(m, n, p, A, B, C, scale);
-}
-int matmul_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale) {
- return _matmul_f32_i8_f64(m, n, p, A, B, C, scale);
-}
-int matmul_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale) {
- return _matmul_f32_i8_i8(m, n, p, A, B, C, scale);
-}
-int matmul_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale) {
- return _matmul_f32_i8_u8(m, n, p, A, B, C, scale);
-}
-int matmul_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale) {
- return _matmul_f32_u8_f32(m, n, p, A, B, C, scale);
-}
-int matmul_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale) {
- return _matmul_f32_u8_f64(m, n, p, A, B, C, scale);
-}
-int matmul_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale) {
- return _matmul_f32_u8_i8(m, n, p, A, B, C, scale);
-}
-int matmul_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale) {
- return _matmul_f32_u8_u8(m, n, p, A, B, C, scale);
-}
-int matmul_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale) {
- return _matmul_f64_f32_f32(m, n, p, A, B, C, scale);
-}
-int matmul_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale) {
- return _matmul_f64_f32_f64(m, n, p, A, B, C, scale);
-}
-int matmul_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale) {
- return _matmul_f64_f32_i8(m, n, p, A, B, C, scale);
-}
-int matmul_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale) {
- return _matmul_f64_f32_u8(m, n, p, A, B, C, scale);
-}
-int matmul_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale) {
- return _matmul_f64_f64_f32(m, n, p, A, B, C, scale);
-}
-int matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
- return _matmul_f64_f64_f64(m, n, p, A, B, C, scale);
-}
-int matmul_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale) {
- return _matmul_f64_f64_i8(m, n, p, A, B, C, scale);
-}
-int matmul_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C, double scale) {
- return _matmul_f64_f64_u8(m, n, p, A, B, C, scale);
-}
-int matmul_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale) {
- return _matmul_f64_i8_f32(m, n, p, A, B, C, scale);
-}
-int matmul_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale) {
- return _matmul_f64_i8_f64(m, n, p, A, B, C, scale);
-}
-int matmul_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale) {
- return _matmul_f64_i8_i8(m, n, p, A, B, C, scale);
-}
-int matmul_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale) {
- return _matmul_f64_i8_u8(m, n, p, A, B, C, scale);
-}
-int matmul_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale) {
- return _matmul_f64_u8_f32(m, n, p, A, B, C, scale);
-}
-int matmul_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C, double scale) {
- return _matmul_f64_u8_f64(m, n, p, A, B, C, scale);
-}
-int matmul_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale) {
- return _matmul_f64_u8_i8(m, n, p, A, B, C, scale);
-}
-int matmul_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C, double scale) {
- return _matmul_f64_u8_u8(m, n, p, A, B, C, scale);
-}
-int matmul_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale) {
- return _matmul_i8_f32_f32(m, n, p, A, B, C, scale);
-}
-int matmul_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C, double scale) {
- return _matmul_i8_f32_f64(m, n, p, A, B, C, scale);
-}
-int matmul_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale) {
- return _matmul_i8_f32_i8(m, n, p, A, B, C, scale);
-}
-int matmul_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale) {
- return _matmul_i8_f32_u8(m, n, p, A, B, C, scale);
-}
-int matmul_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale) {
- return _matmul_i8_f64_f32(m, n, p, A, B, C, scale);
-}
-int matmul_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale) {
- return _matmul_i8_f64_f64(m, n, p, A, B, C, scale);
-}
-int matmul_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale) {
- return _matmul_i8_f64_i8(m, n, p, A, B, C, scale);
-}
-int matmul_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale) {
- return _matmul_i8_f64_u8(m, n, p, A, B, C, scale);
-}
-int matmul_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale) {
- return _matmul_i8_i8_f32(m, n, p, A, B, C, scale);
-}
-int matmul_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale) {
- return _matmul_i8_i8_f64(m, n, p, A, B, C, scale);
-}
-int matmul_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale) {
- return _matmul_i8_i8_i8(m, n, p, A, B, C, scale);
-}
-int matmul_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale) {
- return _matmul_i8_i8_u8(m, n, p, A, B, C, scale);
-}
-int matmul_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale) {
- return _matmul_i8_u8_f32(m, n, p, A, B, C, scale);
-}
-int matmul_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale) {
- return _matmul_i8_u8_f64(m, n, p, A, B, C, scale);
-}
-int matmul_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale) {
- return _matmul_i8_u8_i8(m, n, p, A, B, C, scale);
-}
-int matmul_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale) {
- return _matmul_i8_u8_u8(m, n, p, A, B, C, scale);
-}
-int matmul_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale) {
- return _matmul_u8_f32_f32(m, n, p, A, B, C, scale);
-}
-int matmul_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale) {
- return _matmul_u8_f32_f64(m, n, p, A, B, C, scale);
-}
-int matmul_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale) {
- return _matmul_u8_f32_i8(m, n, p, A, B, C, scale);
-}
-int matmul_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale) {
- return _matmul_u8_f32_u8(m, n, p, A, B, C, scale);
-}
-int matmul_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale) {
- return _matmul_u8_f64_f32(m, n, p, A, B, C, scale);
-}
-int matmul_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C, double scale) {
- return _matmul_u8_f64_f64(m, n, p, A, B, C, scale);
-}
-int matmul_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale) {
- return _matmul_u8_f64_i8(m, n, p, A, B, C, scale);
-}
-int matmul_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C, double scale) {
- return _matmul_u8_f64_u8(m, n, p, A, B, C, scale);
-}
-int matmul_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale) {
- return _matmul_u8_i8_f32(m, n, p, A, B, C, scale);
-}
-int matmul_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale) {
- return _matmul_u8_i8_f64(m, n, p, A, B, C, scale);
-}
-int matmul_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale) {
- return _matmul_u8_i8_i8(m, n, p, A, B, C, scale);
-}
-int matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
- return _matmul_u8_i8_u8(m, n, p, A, B, C, scale);
-}
-int matmul_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale) {
- return _matmul_u8_u8_f32(m, n, p, A, B, C, scale);
-}
-int matmul_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale) {
- return _matmul_u8_u8_f64(m, n, p, A, B, C, scale);
-}
-int matmul_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale) {
- return _matmul_u8_u8_i8(m, n, p, A, B, C, scale);
-}
-int matmul_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale) {
- return _matmul_u8_u8_u8(m, n, p, A, B, C, scale);
-}
+static int _matmul_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C,
+ double scale) { /* first-call thunk: rebinds the matmul_u8_f64_u8 pointer, then forwards the call */
+ static int initialized = 0; /* nonzero once the public pointer has been rebound */
+ if (!initialized) {
+ matmul_u8_f64_u8 = matmul_scalar_u8_f64_u8; /* no CPU-feature check in this thunk: target is always this one */
+ initialized = 1; /* NOTE(review): plain static flag, no synchronisation visible here -- confirm first calls cannot race */
+ }
+ return matmul_u8_f64_u8(m, n, p, A, B, C, scale); /* returns whatever the bound implementation returns */
+}
+/* Public entry points: each pointer is initialised to the static _matmul_* thunk of the same name. */
+int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double) = _matmul_f32_f32_f32;
+int (*matmul_f32_f32_f64)(size_t, size_t, size_t, const float *, const float *, double *, double) = _matmul_f32_f32_f64;
+int (*matmul_f32_f32_i8)(size_t, size_t, size_t, const float *, const float *, int8_t *, double) = _matmul_f32_f32_i8;
+int (*matmul_f32_f32_u8)(size_t, size_t, size_t, const float *, const float *, uint8_t *, double) = _matmul_f32_f32_u8;
+int (*matmul_f32_f64_f32)(size_t, size_t, size_t, const float *, const double *, float *, double) = _matmul_f32_f64_f32;
+int (*matmul_f32_f64_f64)(size_t, size_t, size_t, const float *, const double *, double *,
+ double) = _matmul_f32_f64_f64;
+int (*matmul_f32_f64_i8)(size_t, size_t, size_t, const float *, const double *, int8_t *, double) = _matmul_f32_f64_i8;
+int (*matmul_f32_f64_u8)(size_t, size_t, size_t, const float *, const double *, uint8_t *, double) = _matmul_f32_f64_u8;
+int (*matmul_f32_i8_f32)(size_t, size_t, size_t, const float *, const int8_t *, float *, double) = _matmul_f32_i8_f32;
+int (*matmul_f32_i8_f64)(size_t, size_t, size_t, const float *, const int8_t *, double *, double) = _matmul_f32_i8_f64;
+int (*matmul_f32_i8_i8)(size_t, size_t, size_t, const float *, const int8_t *, int8_t *, double) = _matmul_f32_i8_i8;
+int (*matmul_f32_i8_u8)(size_t, size_t, size_t, const float *, const int8_t *, uint8_t *, double) = _matmul_f32_i8_u8;
+int (*matmul_f32_u8_f32)(size_t, size_t, size_t, const float *, const uint8_t *, float *, double) = _matmul_f32_u8_f32;
+int (*matmul_f32_u8_f64)(size_t, size_t, size_t, const float *, const uint8_t *, double *, double) = _matmul_f32_u8_f64;
+int (*matmul_f32_u8_i8)(size_t, size_t, size_t, const float *, const uint8_t *, int8_t *, double) = _matmul_f32_u8_i8;
+int (*matmul_f32_u8_u8)(size_t, size_t, size_t, const float *, const uint8_t *, uint8_t *, double) = _matmul_f32_u8_u8;
+int (*matmul_f64_f32_f32)(size_t, size_t, size_t, const double *, const float *, float *, double) = _matmul_f64_f32_f32;
+int (*matmul_f64_f32_f64)(size_t, size_t, size_t, const double *, const float *, double *,
+ double) = _matmul_f64_f32_f64;
+int (*matmul_f64_f32_i8)(size_t, size_t, size_t, const double *, const float *, int8_t *, double) = _matmul_f64_f32_i8;
+int (*matmul_f64_f32_u8)(size_t, size_t, size_t, const double *, const float *, uint8_t *, double) = _matmul_f64_f32_u8;
+int (*matmul_f64_f64_f32)(size_t, size_t, size_t, const double *, const double *, float *,
+ double) = _matmul_f64_f64_f32;
+int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *,
+ double) = _matmul_f64_f64_f64;
+int (*matmul_f64_f64_i8)(size_t, size_t, size_t, const double *, const double *, int8_t *, double) = _matmul_f64_f64_i8;
+int (*matmul_f64_f64_u8)(size_t, size_t, size_t, const double *, const double *, uint8_t *,
+ double) = _matmul_f64_f64_u8;
+int (*matmul_f64_i8_f32)(size_t, size_t, size_t, const double *, const int8_t *, float *, double) = _matmul_f64_i8_f32;
+int (*matmul_f64_i8_f64)(size_t, size_t, size_t, const double *, const int8_t *, double *, double) = _matmul_f64_i8_f64;
+int (*matmul_f64_i8_i8)(size_t, size_t, size_t, const double *, const int8_t *, int8_t *, double) = _matmul_f64_i8_i8;
+int (*matmul_f64_i8_u8)(size_t, size_t, size_t, const double *, const int8_t *, uint8_t *, double) = _matmul_f64_i8_u8;
+int (*matmul_f64_u8_f32)(size_t, size_t, size_t, const double *, const uint8_t *, float *, double) = _matmul_f64_u8_f32;
+int (*matmul_f64_u8_f64)(size_t, size_t, size_t, const double *, const uint8_t *, double *,
+ double) = _matmul_f64_u8_f64;
+int (*matmul_f64_u8_i8)(size_t, size_t, size_t, const double *, const uint8_t *, int8_t *, double) = _matmul_f64_u8_i8;
+int (*matmul_f64_u8_u8)(size_t, size_t, size_t, const double *, const uint8_t *, uint8_t *, double) = _matmul_f64_u8_u8;
+int (*matmul_i8_f32_f32)(size_t, size_t, size_t, const int8_t *, const float *, float *, double) = _matmul_i8_f32_f32;
+int (*matmul_i8_f32_f64)(size_t, size_t, size_t, const int8_t *, const float *, double *, double) = _matmul_i8_f32_f64;
+int (*matmul_i8_f32_i8)(size_t, size_t, size_t, const int8_t *, const float *, int8_t *, double) = _matmul_i8_f32_i8;
+int (*matmul_i8_f32_u8)(size_t, size_t, size_t, const int8_t *, const float *, uint8_t *, double) = _matmul_i8_f32_u8;
+int (*matmul_i8_f64_f32)(size_t, size_t, size_t, const int8_t *, const double *, float *, double) = _matmul_i8_f64_f32;
+int (*matmul_i8_f64_f64)(size_t, size_t, size_t, const int8_t *, const double *, double *, double) = _matmul_i8_f64_f64;
+int (*matmul_i8_f64_i8)(size_t, size_t, size_t, const int8_t *, const double *, int8_t *, double) = _matmul_i8_f64_i8;
+int (*matmul_i8_f64_u8)(size_t, size_t, size_t, const int8_t *, const double *, uint8_t *, double) = _matmul_i8_f64_u8;
+int (*matmul_i8_i8_f32)(size_t, size_t, size_t, const int8_t *, const int8_t *, float *, double) = _matmul_i8_i8_f32;
+int (*matmul_i8_i8_f64)(size_t, size_t, size_t, const int8_t *, const int8_t *, double *, double) = _matmul_i8_i8_f64;
+int (*matmul_i8_i8_i8)(size_t, size_t, size_t, const int8_t *, const int8_t *, int8_t *, double) = _matmul_i8_i8_i8;
+int (*matmul_i8_i8_u8)(size_t, size_t, size_t, const int8_t *, const int8_t *, uint8_t *, double) = _matmul_i8_i8_u8;
+int (*matmul_i8_u8_f32)(size_t, size_t, size_t, const int8_t *, const uint8_t *, float *, double) = _matmul_i8_u8_f32;
+int (*matmul_i8_u8_f64)(size_t, size_t, size_t, const int8_t *, const uint8_t *, double *, double) = _matmul_i8_u8_f64;
+int (*matmul_i8_u8_i8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, int8_t *, double) = _matmul_i8_u8_i8;
+int (*matmul_i8_u8_u8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, uint8_t *, double) = _matmul_i8_u8_u8;
+int (*matmul_u8_f32_f32)(size_t, size_t, size_t, const uint8_t *, const float *, float *, double) = _matmul_u8_f32_f32;
+int (*matmul_u8_f32_f64)(size_t, size_t, size_t, const uint8_t *, const float *, double *, double) = _matmul_u8_f32_f64;
+int (*matmul_u8_f32_i8)(size_t, size_t, size_t, const uint8_t *, const float *, int8_t *, double) = _matmul_u8_f32_i8;
+int (*matmul_u8_f32_u8)(size_t, size_t, size_t, const uint8_t *, const float *, uint8_t *, double) = _matmul_u8_f32_u8;
+int (*matmul_u8_f64_f32)(size_t, size_t, size_t, const uint8_t *, const double *, float *, double) = _matmul_u8_f64_f32;
+int (*matmul_u8_f64_f64)(size_t, size_t, size_t, const uint8_t *, const double *, double *,
+ double) = _matmul_u8_f64_f64;
+int (*matmul_u8_f64_i8)(size_t, size_t, size_t, const uint8_t *, const double *, int8_t *, double) = _matmul_u8_f64_i8;
+int (*matmul_u8_f64_u8)(size_t, size_t, size_t, const uint8_t *, const double *, uint8_t *, double) = _matmul_u8_f64_u8;
+int (*matmul_u8_i8_f32)(size_t, size_t, size_t, const uint8_t *, const int8_t *, float *, double) = _matmul_u8_i8_f32;
+int (*matmul_u8_i8_f64)(size_t, size_t, size_t, const uint8_t *, const int8_t *, double *, double) = _matmul_u8_i8_f64;
+int (*matmul_u8_i8_i8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, int8_t *, double) = _matmul_u8_i8_i8;
+int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = _matmul_u8_i8_u8;
+int (*matmul_u8_u8_f32)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, float *, double) = _matmul_u8_u8_f32;
+int (*matmul_u8_u8_f64)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, double *, double) = _matmul_u8_u8_f64;
+int (*matmul_u8_u8_i8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, int8_t *, double) = _matmul_u8_u8_i8;
+int (*matmul_u8_u8_u8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, double) = _matmul_u8_u8_u8;
diff --git a/src/matmul.h b/src/matmul.h
@@ -56,70 +56,70 @@ typedef uint32_t matmul_feature_t;
matmul_feature_t matmul_get_feature(void);
const char *matmul_get_feature_name(matmul_feature_t feat);
-int matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale);
-int matmul_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale);
-int matmul_f32_f32_i8(size_t m, size_t n, size_t p, const float *A, const float *B, int8_t *C, double scale);
-int matmul_f32_f32_u8(size_t m, size_t n, size_t p, const float *A, const float *B, uint8_t *C, double scale);
-int matmul_f32_f64_f32(size_t m, size_t n, size_t p, const float *A, const double *B, float *C, double scale);
-int matmul_f32_f64_f64(size_t m, size_t n, size_t p, const float *A, const double *B, double *C, double scale);
-int matmul_f32_f64_i8(size_t m, size_t n, size_t p, const float *A, const double *B, int8_t *C, double scale);
-int matmul_f32_f64_u8(size_t m, size_t n, size_t p, const float *A, const double *B, uint8_t *C, double scale);
-int matmul_f32_i8_f32(size_t m, size_t n, size_t p, const float *A, const int8_t *B, float *C, double scale);
-int matmul_f32_i8_f64(size_t m, size_t n, size_t p, const float *A, const int8_t *B, double *C, double scale);
-int matmul_f32_i8_i8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, int8_t *C, double scale);
-int matmul_f32_i8_u8(size_t m, size_t n, size_t p, const float *A, const int8_t *B, uint8_t *C, double scale);
-int matmul_f32_u8_f32(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, float *C, double scale);
-int matmul_f32_u8_f64(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, double *C, double scale);
-int matmul_f32_u8_i8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_f32_u8_u8(size_t m, size_t n, size_t p, const float *A, const uint8_t *B, uint8_t *C, double scale);
-int matmul_f64_f32_f32(size_t m, size_t n, size_t p, const double *A, const float *B, float *C, double scale);
-int matmul_f64_f32_f64(size_t m, size_t n, size_t p, const double *A, const float *B, double *C, double scale);
-int matmul_f64_f32_i8(size_t m, size_t n, size_t p, const double *A, const float *B, int8_t *C, double scale);
-int matmul_f64_f32_u8(size_t m, size_t n, size_t p, const double *A, const float *B, uint8_t *C, double scale);
-int matmul_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, const double *B, float *C, double scale);
-int matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale);
-int matmul_f64_f64_i8(size_t m, size_t n, size_t p, const double *A, const double *B, int8_t *C, double scale);
-int matmul_f64_f64_u8(size_t m, size_t n, size_t p, const double *A, const double *B, uint8_t *C, double scale);
-int matmul_f64_i8_f32(size_t m, size_t n, size_t p, const double *A, const int8_t *B, float *C, double scale);
-int matmul_f64_i8_f64(size_t m, size_t n, size_t p, const double *A, const int8_t *B, double *C, double scale);
-int matmul_f64_i8_i8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, int8_t *C, double scale);
-int matmul_f64_i8_u8(size_t m, size_t n, size_t p, const double *A, const int8_t *B, uint8_t *C, double scale);
-int matmul_f64_u8_f32(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, float *C, double scale);
-int matmul_f64_u8_f64(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, double *C, double scale);
-int matmul_f64_u8_i8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_f64_u8_u8(size_t m, size_t n, size_t p, const double *A, const uint8_t *B, uint8_t *C, double scale);
-int matmul_i8_f32_f32(size_t m, size_t n, size_t p, const int8_t *A, const float *B, float *C, double scale);
-int matmul_i8_f32_f64(size_t m, size_t n, size_t p, const int8_t *A, const float *B, double *C, double scale);
-int matmul_i8_f32_i8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, int8_t *C, double scale);
-int matmul_i8_f32_u8(size_t m, size_t n, size_t p, const int8_t *A, const float *B, uint8_t *C, double scale);
-int matmul_i8_f64_f32(size_t m, size_t n, size_t p, const int8_t *A, const double *B, float *C, double scale);
-int matmul_i8_f64_f64(size_t m, size_t n, size_t p, const int8_t *A, const double *B, double *C, double scale);
-int matmul_i8_f64_i8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, int8_t *C, double scale);
-int matmul_i8_f64_u8(size_t m, size_t n, size_t p, const int8_t *A, const double *B, uint8_t *C, double scale);
-int matmul_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale);
-int matmul_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale);
-int matmul_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale);
-int matmul_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale);
-int matmul_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale);
-int matmul_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale);
-int matmul_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale);
-int matmul_u8_f32_f32(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, float *C, double scale);
-int matmul_u8_f32_f64(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, double *C, double scale);
-int matmul_u8_f32_i8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, int8_t *C, double scale);
-int matmul_u8_f32_u8(size_t m, size_t n, size_t p, const uint8_t *A, const float *B, uint8_t *C, double scale);
-int matmul_u8_f64_f32(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, float *C, double scale);
-int matmul_u8_f64_f64(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, double *C, double scale);
-int matmul_u8_f64_i8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, int8_t *C, double scale);
-int matmul_u8_f64_u8(size_t m, size_t n, size_t p, const uint8_t *A, const double *B, uint8_t *C, double scale);
-int matmul_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale);
-int matmul_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale);
-int matmul_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale);
-int matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale);
-int matmul_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale);
-int matmul_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale);
-int matmul_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale);
-int matmul_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale);
+extern int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double); /* argument order for every pointer in this list: m, n, p, A, B, C, scale */
+extern int (*matmul_f32_f32_f64)(size_t, size_t, size_t, const float *, const float *, double *, double);
+extern int (*matmul_f32_f32_i8)(size_t, size_t, size_t, const float *, const float *, int8_t *, double);
+extern int (*matmul_f32_f32_u8)(size_t, size_t, size_t, const float *, const float *, uint8_t *, double);
+extern int (*matmul_f32_f64_f32)(size_t, size_t, size_t, const float *, const double *, float *, double);
+extern int (*matmul_f32_f64_f64)(size_t, size_t, size_t, const float *, const double *, double *, double);
+extern int (*matmul_f32_f64_i8)(size_t, size_t, size_t, const float *, const double *, int8_t *, double);
+extern int (*matmul_f32_f64_u8)(size_t, size_t, size_t, const float *, const double *, uint8_t *, double);
+extern int (*matmul_f32_i8_f32)(size_t, size_t, size_t, const float *, const int8_t *, float *, double);
+extern int (*matmul_f32_i8_f64)(size_t, size_t, size_t, const float *, const int8_t *, double *, double);
+extern int (*matmul_f32_i8_i8)(size_t, size_t, size_t, const float *, const int8_t *, int8_t *, double);
+extern int (*matmul_f32_i8_u8)(size_t, size_t, size_t, const float *, const int8_t *, uint8_t *, double);
+extern int (*matmul_f32_u8_f32)(size_t, size_t, size_t, const float *, const uint8_t *, float *, double);
+extern int (*matmul_f32_u8_f64)(size_t, size_t, size_t, const float *, const uint8_t *, double *, double);
+extern int (*matmul_f32_u8_i8)(size_t, size_t, size_t, const float *, const uint8_t *, int8_t *, double);
+extern int (*matmul_f32_u8_u8)(size_t, size_t, size_t, const float *, const uint8_t *, uint8_t *, double);
+extern int (*matmul_f64_f32_f32)(size_t, size_t, size_t, const double *, const float *, float *, double);
+extern int (*matmul_f64_f32_f64)(size_t, size_t, size_t, const double *, const float *, double *, double);
+extern int (*matmul_f64_f32_i8)(size_t, size_t, size_t, const double *, const float *, int8_t *, double);
+extern int (*matmul_f64_f32_u8)(size_t, size_t, size_t, const double *, const float *, uint8_t *, double);
+extern int (*matmul_f64_f64_f32)(size_t, size_t, size_t, const double *, const double *, float *, double);
+extern int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, double);
+extern int (*matmul_f64_f64_i8)(size_t, size_t, size_t, const double *, const double *, int8_t *, double);
+extern int (*matmul_f64_f64_u8)(size_t, size_t, size_t, const double *, const double *, uint8_t *, double);
+extern int (*matmul_f64_i8_f32)(size_t, size_t, size_t, const double *, const int8_t *, float *, double);
+extern int (*matmul_f64_i8_f64)(size_t, size_t, size_t, const double *, const int8_t *, double *, double);
+extern int (*matmul_f64_i8_i8)(size_t, size_t, size_t, const double *, const int8_t *, int8_t *, double);
+extern int (*matmul_f64_i8_u8)(size_t, size_t, size_t, const double *, const int8_t *, uint8_t *, double);
+extern int (*matmul_f64_u8_f32)(size_t, size_t, size_t, const double *, const uint8_t *, float *, double);
+extern int (*matmul_f64_u8_f64)(size_t, size_t, size_t, const double *, const uint8_t *, double *, double);
+extern int (*matmul_f64_u8_i8)(size_t, size_t, size_t, const double *, const uint8_t *, int8_t *, double);
+extern int (*matmul_f64_u8_u8)(size_t, size_t, size_t, const double *, const uint8_t *, uint8_t *, double);
+extern int (*matmul_i8_f32_f32)(size_t, size_t, size_t, const int8_t *, const float *, float *, double);
+extern int (*matmul_i8_f32_f64)(size_t, size_t, size_t, const int8_t *, const float *, double *, double);
+extern int (*matmul_i8_f32_i8)(size_t, size_t, size_t, const int8_t *, const float *, int8_t *, double);
+extern int (*matmul_i8_f32_u8)(size_t, size_t, size_t, const int8_t *, const float *, uint8_t *, double);
+extern int (*matmul_i8_f64_f32)(size_t, size_t, size_t, const int8_t *, const double *, float *, double);
+extern int (*matmul_i8_f64_f64)(size_t, size_t, size_t, const int8_t *, const double *, double *, double);
+extern int (*matmul_i8_f64_i8)(size_t, size_t, size_t, const int8_t *, const double *, int8_t *, double);
+extern int (*matmul_i8_f64_u8)(size_t, size_t, size_t, const int8_t *, const double *, uint8_t *, double);
+extern int (*matmul_i8_i8_f32)(size_t, size_t, size_t, const int8_t *, const int8_t *, float *, double);
+extern int (*matmul_i8_i8_f64)(size_t, size_t, size_t, const int8_t *, const int8_t *, double *, double);
+extern int (*matmul_i8_i8_i8)(size_t, size_t, size_t, const int8_t *, const int8_t *, int8_t *, double);
+extern int (*matmul_i8_i8_u8)(size_t, size_t, size_t, const int8_t *, const int8_t *, uint8_t *, double);
+extern int (*matmul_i8_u8_f32)(size_t, size_t, size_t, const int8_t *, const uint8_t *, float *, double);
+extern int (*matmul_i8_u8_f64)(size_t, size_t, size_t, const int8_t *, const uint8_t *, double *, double);
+extern int (*matmul_i8_u8_i8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, int8_t *, double);
+extern int (*matmul_i8_u8_u8)(size_t, size_t, size_t, const int8_t *, const uint8_t *, uint8_t *, double);
+extern int (*matmul_u8_f32_f32)(size_t, size_t, size_t, const uint8_t *, const float *, float *, double);
+extern int (*matmul_u8_f32_f64)(size_t, size_t, size_t, const uint8_t *, const float *, double *, double);
+extern int (*matmul_u8_f32_i8)(size_t, size_t, size_t, const uint8_t *, const float *, int8_t *, double);
+extern int (*matmul_u8_f32_u8)(size_t, size_t, size_t, const uint8_t *, const float *, uint8_t *, double);
+extern int (*matmul_u8_f64_f32)(size_t, size_t, size_t, const uint8_t *, const double *, float *, double);
+extern int (*matmul_u8_f64_f64)(size_t, size_t, size_t, const uint8_t *, const double *, double *, double);
+extern int (*matmul_u8_f64_i8)(size_t, size_t, size_t, const uint8_t *, const double *, int8_t *, double);
+extern int (*matmul_u8_f64_u8)(size_t, size_t, size_t, const uint8_t *, const double *, uint8_t *, double);
+extern int (*matmul_u8_i8_f32)(size_t, size_t, size_t, const uint8_t *, const int8_t *, float *, double);
+extern int (*matmul_u8_i8_f64)(size_t, size_t, size_t, const uint8_t *, const int8_t *, double *, double);
+extern int (*matmul_u8_i8_i8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, int8_t *, double);
+extern int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double);
+extern int (*matmul_u8_u8_f32)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, float *, double);
+extern int (*matmul_u8_u8_f64)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, double *, double);
+extern int (*matmul_u8_u8_i8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, int8_t *, double);
+extern int (*matmul_u8_u8_u8)(size_t, size_t, size_t, const uint8_t *, const uint8_t *, uint8_t *, double);
int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale);
int matmul_scalar_f32_f32_f64(size_t m, size_t n, size_t p, const float *A, const float *B, double *C, double scale);
@@ -209,6 +209,44 @@ int matmul_avx512_f64_f64_f32(size_t m, size_t n, size_t p, const double *A, con
int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale);
#endif
+#ifdef __AVX2__ /* prototypes below exist only when the compiler defines __AVX2__ */
+int matmul_avx2_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale); /* 8-bit integer inputs: name is matmul_avx2_<A>_<B>_<C> element types */
+int matmul_avx2_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale);
+int matmul_avx2_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale);
+int matmul_avx2_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale);
+int matmul_avx2_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale); /* int8_t A, uint8_t B */
+int matmul_avx2_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale);
+int matmul_avx2_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale);
+int matmul_avx2_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale);
+int matmul_avx2_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale); /* uint8_t A, int8_t B */
+int matmul_avx2_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale);
+int matmul_avx2_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale);
+int matmul_avx2_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale);
+int matmul_avx2_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale); /* uint8_t A, uint8_t B */
+int matmul_avx2_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale);
+int matmul_avx2_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale);
+int matmul_avx2_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale);
+#endif /* __AVX2__ */
+
+#ifdef __AVX512F__ /* prototypes below exist only when the compiler defines __AVX512F__ */
+int matmul_avx512_i8_i8_f32(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, float *C, double scale); /* 8-bit integer inputs: name is matmul_avx512_<A>_<B>_<C> element types */
+int matmul_avx512_i8_i8_f64(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, double *C, double scale);
+int matmul_avx512_i8_i8_i8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, int8_t *C, double scale);
+int matmul_avx512_i8_i8_u8(size_t m, size_t n, size_t p, const int8_t *A, const int8_t *B, uint8_t *C, double scale);
+int matmul_avx512_i8_u8_f32(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, float *C, double scale); /* int8_t A, uint8_t B */
+int matmul_avx512_i8_u8_f64(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, double *C, double scale);
+int matmul_avx512_i8_u8_i8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, int8_t *C, double scale);
+int matmul_avx512_i8_u8_u8(size_t m, size_t n, size_t p, const int8_t *A, const uint8_t *B, uint8_t *C, double scale);
+int matmul_avx512_u8_i8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, float *C, double scale); /* uint8_t A, int8_t B */
+int matmul_avx512_u8_i8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, double *C, double scale);
+int matmul_avx512_u8_i8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, int8_t *C, double scale);
+int matmul_avx512_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale);
+int matmul_avx512_u8_u8_f32(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, float *C, double scale); /* uint8_t A, uint8_t B */
+int matmul_avx512_u8_u8_f64(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, double *C, double scale);
+int matmul_avx512_u8_u8_i8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, int8_t *C, double scale);
+int matmul_avx512_u8_u8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const uint8_t *B, uint8_t *C, double scale);
+#endif /* __AVX512F__ */
+
#define matmul(m, n, p, A, B, C, scale) \
_Generic((A), \
float: _Generic((B), \