matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

matmul.c (55943B)


      1 /*
      2  * Copyright (c) 2026 finwo
      3  *
      4  * Permission is hereby granted, free of charge, to any person obtaining a copy of
      5  * this software and associated documentation files (the "Software"), to use, copy,
      6  * modify, and distribute the Software, subject to the following conditions:
      7  *
      8  *  1. Redistributions of source code must retain the above copyright notice, this
      9  *     list of conditions, and the following disclaimer.
     10  *
     11  *  2. Redistributions in binary form, or any public offering of the Software
     12  *     (including hosted or managed services), must reproduce the above copyright
     13  *     notice, this list of conditions, and the following disclaimer in the
     14  *     documentation and/or other materials provided.
     15  *
     16  *  3. Any redistribution or public offering of the Software must clearly attribute
     17  *     the Software to the original copyright holder, reference this License, and
     18  *     include a link to the official project repository or website.
     19  *
     20  *  4. The Software may not be renamed, rebranded, or marketed in a manner that
     21  *     implies it is an independent or proprietary product. Derivative works must
     22  *     clearly state that they are based on the Software.
     23  *
     24  *  5. Modifications to copies of the Software must carry prominent notices stating
     25  *     that changes were made, the nature of the modifications, and the date of the
     26  *     modifications.
     27  *
     28  * Any violation of these conditions terminates the permissions granted herein.
     29  *
     30  * THIS SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     31  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
     32  * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT
     33  * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
     34  * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
     35  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
     36  */
     37 
     38 #define _GNU_SOURCE
     39 #include "matmul.h"
     40 
     41 #include <stdio.h>
     42 #include <stdlib.h>
     43 #include <string.h>
     44 
     45 #ifdef __AVX2__
     46 #include <immintrin.h>
     47 #endif
     48 
     49 #ifdef __AVX512F__
     50 #include <immintrin.h>
     51 #endif
     52 
     53 #ifdef __AVXVNNI__
     54 #include <immintrin.h>
     55 #endif
     56 
     57 #ifdef __AVX512VNNI__
     58 #include <immintrin.h>
     59 #endif
     60 
     61 #ifdef _OPENMP
     62 #include <omp.h>
     63 #endif
     64 
/* Bit flags naming the kernel families this file can select between.
 * Several bits may be set at once in a matmul_feature_t mask. */
#define MATMUL_FLAG_SCALAR     (1 << 0)
#define MATMUL_FLAG_AVX2       (1 << 1)
#define MATMUL_FLAG_AVXVNNI    (1 << 2)
#define MATMUL_FLAG_AVX512     (1 << 3)
#define MATMUL_FLAG_AVX512VNNI (1 << 4)

/* Bitmask built from the MATMUL_FLAG_* values above. */
typedef uint32_t matmul_feature_t;

/* Cached feature mask and the flag recording whether it has been computed;
 * both are filled in on the first matmul_get_feature() call. */
static matmul_feature_t g_feature     = 0;
static int              g_initialized = 0;
     75 
     76 static void init_feature(void) {
     77   g_feature = MATMUL_FLAG_SCALAR;
     78 #ifdef __AVX512VNNI__
     79   if (__builtin_cpu_supports("avx512vnni")) g_feature |= MATMUL_FLAG_AVX512VNNI;
     80 #endif
     81 #ifdef __AVX512F__
     82   if (__builtin_cpu_supports("avx512f")) g_feature |= MATMUL_FLAG_AVX512;
     83 #endif
     84 #ifdef __AVXVNNI__
     85   if (__builtin_cpu_supports("avxvnni")) g_feature |= MATMUL_FLAG_AVXVNNI;
     86 #endif
     87 #ifdef __AVX2__
     88   if (__builtin_cpu_supports("avx2")) g_feature |= MATMUL_FLAG_AVX2;
     89 #endif
     90 }
     91 
     92 matmul_feature_t matmul_get_feature(void) {
     93   if (!g_initialized) {
     94     init_feature();
     95     g_initialized = 1;
     96   }
     97   return g_feature;
     98 }
     99 
    100 const char *matmul_get_feature_name(matmul_feature_t feat) {
    101   if (feat & MATMUL_FLAG_AVX512VNNI) return "avx512vnni";
    102   if (feat & MATMUL_FLAG_AVX512) return "avx512";
    103   if (feat & MATMUL_FLAG_AVXVNNI) return "avxvnni";
    104   if (feat & MATMUL_FLAG_AVX2) return "avx2";
    105   if (feat & MATMUL_FLAG_SCALAR) return "scalar";
    106   return "unknown";
    107 }
    108 
    109 int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
    110   (void)scale;
    111   const size_t ib = 64;
    112   const size_t jb = 64;
    113   const size_t kb = 32;
    114 
    115 #pragma omp parallel for schedule(static)
    116   for (size_t ii = 0; ii < m; ii += ib) {
    117     size_t i_end = (ii + ib < m) ? ii + ib : m;
    118     for (size_t jj = 0; jj < p; jj += jb) {
    119       size_t  j_end = (jj + jb < p) ? jj + jb : p;
    120       size_t  ti    = i_end - ii;
    121       size_t  tj    = j_end - jj;
    122       int32_t acc[64 * 64];
    123       memset(acc, 0, ti * tj * sizeof(int32_t));
    124 
    125       for (size_t kk = 0; kk < n; kk += kb) {
    126         size_t k_end = (kk + kb < n) ? kk + kb : n;
    127         for (size_t i = ii; i < i_end; i++) {
    128           size_t li = i - ii;
    129           for (size_t j = jj; j < j_end; j++) {
    130             size_t lj  = j - jj;
    131             int    sum = 0;
    132             for (size_t k = kk; k < k_end; k++) {
    133               sum += (int)A[i * n + k] * (int)B[k * p + j];
    134             }
    135             acc[li * tj + lj] += sum;
    136           }
    137         }
    138       }
    139 
    140       for (size_t i = ii; i < i_end; i++) {
    141         size_t li = i - ii;
    142         for (size_t j = jj; j < j_end; j++) {
    143           size_t lj = j - jj;
    144           int    v  = acc[li * tj + lj];
    145           if (scale > 1.0) v = (int)(v / scale);
    146           if (v > 255)
    147             v = 255;
    148           else if (v < 0)
    149             v = 0;
    150           C[i * p + j] = (uint8_t)v;
    151         }
    152       }
    153     }
    154   }
    155   return 0;
    156 }
    157 
    158 #ifdef __AVX512VNNI__
    159 static void pack_b_i8(size_t n, size_t p, const int8_t *B, int8_t *B_packed) {
    160   size_t n4  = n / 4;
    161   size_t p16 = p / 16;
    162   for (size_t j16 = 0; j16 < p16; j16++) {
    163     for (size_t k4 = 0; k4 < n4; k4++) {
    164       int8_t *dst = &B_packed[(j16 * n4 + k4) * 64];
    165       for (size_t dj = 0; dj < 16; dj++) {
    166         size_t j = j16 * 16 + dj;
    167         for (size_t dk = 0; dk < 4; dk++) {
    168           dst[dj * 4 + dk] = B[(k4 * 4 + dk) * p + j];
    169         }
    170       }
    171     }
    172   }
    173 }
    174 
    175 int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C,
    176                                double scale) {
    177   (void)scale;
    178 
    179   size_t n4  = n / 4;
    180   size_t p16 = p / 16;
    181 
    182   int8_t *B_packed;
    183   if (posix_memalign((void **)&B_packed, 64, p16 * n4 * 64) != 0) return -1;
    184   pack_b_i8(n, p, B, B_packed);
    185 
    186   const uint32_t *A32 = (const uint32_t *)A;
    187 
    188 #pragma omp parallel for schedule(static)
    189   for (size_t i = 0; i < m; i++) {
    190     for (size_t j16 = 0; j16 < p16; j16++) {
    191       __m512i result = _mm512_setzero_si512();
    192       for (size_t k4 = 0; k4 < n4; k4++) {
    193         __m512i a_val = _mm512_set1_epi32(A32[i * n4 + k4]);
    194         __m512i b_val = _mm512_load_si512((__m512i const *)&B_packed[(j16 * n4 + k4) * 64]);
    195         result        = _mm512_dpbusd_epi32(result, a_val, b_val);
    196       }
    197       int32_t tmp[16] __attribute__((aligned(64)));
    198       _mm512_store_si512(tmp, result);
    199       for (size_t dj = 0; dj < 16; dj++) {
    200         int32_t v = tmp[dj];
    201         if (scale > 1.0) v = (int32_t)(v / scale);
    202         if (v > 255)
    203           v = 255;
    204         else if (v < 0)
    205           v = 0;
    206         C[i * p + j16 * 16 + dj] = (uint8_t)v;
    207       }
    208     }
    209     for (size_t j = p16 * 16; j < p; j++) {
    210       int32_t sum = 0;
    211       for (size_t k = 0; k < n; k++) {
    212         sum += (int)A[i * n + k] * (int)B[k * p + j];
    213       }
    214       if (scale > 1.0) sum = (int32_t)(sum / scale);
    215       if (sum > 255)
    216         sum = 255;
    217       else if (sum < 0)
    218         sum = 0;
    219       C[i * p + j] = (uint8_t)sum;
    220     }
    221   }
    222 
    223   free(B_packed);
    224   return 0;
    225 }
    226 #endif
    227 
    228 static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
    229   static int initialized = 0;
    230   if (!initialized) {
    231     matmul_feature_t feat = matmul_get_feature();
    232 #ifdef __AVX512VNNI__
    233     if (feat & MATMUL_FLAG_AVX512VNNI)
    234       matmul_u8_i8_u8 = matmul_avx512vnni_u8_i8_u8;
    235     else
    236 #endif
    237       matmul_u8_i8_u8 = matmul_scalar_u8_i8_u8;
    238     initialized = 1;
    239   }
    240   return matmul_u8_i8_u8(m, n, p, A, B, C, scale);
    241 }
    242 
    243 int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = _matmul_u8_i8_u8;
    244 
    245 /* ========================================================================== */
    246 /* f32_f32_f32 implementations                                                */
    247 /* ========================================================================== */
    248 
    249 int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
    250   const size_t ib = 64;
    251   const size_t jb = 64;
    252   const size_t kb = 16;
    253 
    254 #pragma omp parallel for schedule(static)
    255   for (size_t ii = 0; ii < m; ii += ib) {
    256     size_t i_end = (ii + ib < m) ? ii + ib : m;
    257     for (size_t jj = 0; jj < p; jj += jb) {
    258       size_t j_end = (jj + jb < p) ? jj + jb : p;
    259       size_t ti    = i_end - ii;
    260       size_t tj    = j_end - jj;
    261       double acc[64 * 64];
    262       memset(acc, 0, ti * tj * sizeof(double));
    263 
    264       for (size_t kk = 0; kk < n; kk += kb) {
    265         size_t k_end = (kk + kb < n) ? kk + kb : n;
    266         for (size_t i = ii; i < i_end; i++) {
    267           size_t li = i - ii;
    268           for (size_t j = jj; j < j_end; j++) {
    269             size_t lj  = j - jj;
    270             double sum = 0.0;
    271             for (size_t k = kk; k < k_end; k++) {
    272               sum += (double)A[i * n + k] * (double)B[k * p + j];
    273             }
    274             acc[li * tj + lj] += sum;
    275           }
    276         }
    277       }
    278 
    279       for (size_t i = ii; i < i_end; i++) {
    280         size_t li = i - ii;
    281         for (size_t j = jj; j < j_end; j++) {
    282           size_t lj = j - jj;
    283           double v  = acc[li * tj + lj];
    284           if (scale > 1.0) v /= scale;
    285           C[i * p + j] = (float)v;
    286         }
    287       }
    288     }
    289   }
    290   return 0;
    291 }
    292 
    293 #ifdef __AVX2__
    294 static void pack_b_f32(size_t n, size_t p, const float *B, float *B_packed) {
    295   size_t n8 = n / 8;
    296   size_t p8 = p / 8;
    297   for (size_t j8 = 0; j8 < p8; j8++) {
    298     for (size_t k8 = 0; k8 < n8; k8++) {
    299       float *dst = &B_packed[(j8 * n8 + k8) * 64];
    300       for (size_t dk = 0; dk < 8; dk++) {
    301         size_t k = k8 * 8 + dk;
    302         for (size_t dj = 0; dj < 8; dj++) {
    303           size_t j         = j8 * 8 + dj;
    304           dst[dk * 8 + dj] = B[k * p + j];
    305         }
    306       }
    307     }
    308   }
    309 }
    310 
    311 int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
    312   const size_t ib = 64;
    313   const size_t jb = 64;
    314   const size_t kb = 16;
    315 
    316   size_t n8 = n / 8;
    317   size_t p8 = p / 8;
    318 
    319   float *B_packed;
    320   if (posix_memalign((void **)&B_packed, 64, p8 * n8 * 64 * sizeof(float)) != 0) return -1;
    321   pack_b_f32(n, p, B, B_packed);
    322 
    323   float inv_scale = (scale > 1.0) ? 1.0f / (float)scale : 1.0f;
    324 
    325 #pragma omp parallel for schedule(static)
    326   for (size_t ii = 0; ii < m; ii += ib) {
    327     size_t i_end = (ii + ib < m) ? ii + ib : m;
    328     for (size_t jj = 0; jj < p8 * 8; jj += jb) {
    329       size_t j_end = (jj + jb < p8 * 8) ? jj + jb : p8 * 8;
    330       size_t ti    = i_end - ii;
    331       size_t tj    = j_end - jj;
    332       float  acc[64 * 64];
    333       memset(acc, 0, ti * tj * sizeof(float));
    334 
    335       for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) {
    336         size_t li0 = i4 - ii;
    337         size_t li1 = li0 + 1;
    338         size_t li2 = li0 + 2;
    339         size_t li3 = li0 + 3;
    340         size_t rn0 = i4 * n;
    341         size_t rn1 = (i4 + 1) * n;
    342         size_t rn2 = (i4 + 2) * n;
    343         size_t rn3 = (i4 + 3) * n;
    344 
    345         for (size_t j = jj; j + 8 <= j_end; j += 8) {
    346           size_t lj    = j - jj;
    347           size_t j8idx = j / 8;
    348           __m256 acc00 = _mm256_setzero_ps();
    349           __m256 acc01 = _mm256_setzero_ps();
    350           __m256 acc02 = _mm256_setzero_ps();
    351           __m256 acc03 = _mm256_setzero_ps();
    352           __m256 acc10 = _mm256_setzero_ps();
    353           __m256 acc11 = _mm256_setzero_ps();
    354           __m256 acc12 = _mm256_setzero_ps();
    355           __m256 acc13 = _mm256_setzero_ps();
    356           __m256 acc20 = _mm256_setzero_ps();
    357           __m256 acc21 = _mm256_setzero_ps();
    358           __m256 acc22 = _mm256_setzero_ps();
    359           __m256 acc23 = _mm256_setzero_ps();
    360           __m256 acc30 = _mm256_setzero_ps();
    361           __m256 acc31 = _mm256_setzero_ps();
    362           __m256 acc32 = _mm256_setzero_ps();
    363           __m256 acc33 = _mm256_setzero_ps();
    364 
    365           for (size_t kk = 0; kk < n; kk += kb) {
    366             size_t k_end  = (kk + kb < n) ? kk + kb : n;
    367             size_t k_end8 = kk + (k_end - kk) / 8 * 8;
    368             size_t k      = kk;
    369             for (; k + 4 <= k_end8; k += 4) {
    370               size_t k8_0 = k / 8, dk_0 = k % 8;
    371               size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8;
    372               size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8;
    373               size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8;
    374               __m256 a0 = _mm256_set1_ps(A[rn0 + k]);
    375               __m256 a1 = _mm256_set1_ps(A[rn0 + k + 1]);
    376               __m256 a2 = _mm256_set1_ps(A[rn0 + k + 2]);
    377               __m256 a3 = _mm256_set1_ps(A[rn0 + k + 3]);
    378               __m256 b0 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]);
    379               __m256 b1 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]);
    380               __m256 b2 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]);
    381               __m256 b3 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]);
    382               acc00     = _mm256_fmadd_ps(a0, b0, acc00);
    383               acc01     = _mm256_fmadd_ps(a1, b1, acc01);
    384               acc02     = _mm256_fmadd_ps(a2, b2, acc02);
    385               acc03     = _mm256_fmadd_ps(a3, b3, acc03);
    386               a0        = _mm256_set1_ps(A[rn1 + k]);
    387               a1        = _mm256_set1_ps(A[rn1 + k + 1]);
    388               a2        = _mm256_set1_ps(A[rn1 + k + 2]);
    389               a3        = _mm256_set1_ps(A[rn1 + k + 3]);
    390               acc10     = _mm256_fmadd_ps(a0, b0, acc10);
    391               acc11     = _mm256_fmadd_ps(a1, b1, acc11);
    392               acc12     = _mm256_fmadd_ps(a2, b2, acc12);
    393               acc13     = _mm256_fmadd_ps(a3, b3, acc13);
    394               a0        = _mm256_set1_ps(A[rn2 + k]);
    395               a1        = _mm256_set1_ps(A[rn2 + k + 1]);
    396               a2        = _mm256_set1_ps(A[rn2 + k + 2]);
    397               a3        = _mm256_set1_ps(A[rn2 + k + 3]);
    398               acc20     = _mm256_fmadd_ps(a0, b0, acc20);
    399               acc21     = _mm256_fmadd_ps(a1, b1, acc21);
    400               acc22     = _mm256_fmadd_ps(a2, b2, acc22);
    401               acc23     = _mm256_fmadd_ps(a3, b3, acc23);
    402               a0        = _mm256_set1_ps(A[rn3 + k]);
    403               a1        = _mm256_set1_ps(A[rn3 + k + 1]);
    404               a2        = _mm256_set1_ps(A[rn3 + k + 2]);
    405               a3        = _mm256_set1_ps(A[rn3 + k + 3]);
    406               acc30     = _mm256_fmadd_ps(a0, b0, acc30);
    407               acc31     = _mm256_fmadd_ps(a1, b1, acc31);
    408               acc32     = _mm256_fmadd_ps(a2, b2, acc32);
    409               acc33     = _mm256_fmadd_ps(a3, b3, acc33);
    410             }
    411             for (; k < k_end; k++) {
    412               size_t k8 = k / 8, dk = k % 8;
    413               __m256 a_bcast = _mm256_set1_ps(A[rn0 + k]);
    414               __m256 b_val   = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
    415               acc00          = _mm256_fmadd_ps(a_bcast, b_val, acc00);
    416               a_bcast        = _mm256_set1_ps(A[rn1 + k]);
    417               acc10          = _mm256_fmadd_ps(a_bcast, b_val, acc10);
    418               a_bcast        = _mm256_set1_ps(A[rn2 + k]);
    419               acc20          = _mm256_fmadd_ps(a_bcast, b_val, acc20);
    420               a_bcast        = _mm256_set1_ps(A[rn3 + k]);
    421               acc30          = _mm256_fmadd_ps(a_bcast, b_val, acc30);
    422             }
    423           }
    424 
    425           acc00 = _mm256_add_ps(acc00, acc01);
    426           acc02 = _mm256_add_ps(acc02, acc03);
    427           acc00 = _mm256_add_ps(acc00, acc02);
    428           acc10 = _mm256_add_ps(acc10, acc11);
    429           acc12 = _mm256_add_ps(acc12, acc13);
    430           acc10 = _mm256_add_ps(acc10, acc12);
    431           acc20 = _mm256_add_ps(acc20, acc21);
    432           acc22 = _mm256_add_ps(acc22, acc23);
    433           acc20 = _mm256_add_ps(acc20, acc22);
    434           acc30 = _mm256_add_ps(acc30, acc31);
    435           acc32 = _mm256_add_ps(acc32, acc33);
    436           acc30 = _mm256_add_ps(acc30, acc32);
    437 
    438           float tmp[8] __attribute__((aligned(32)));
    439           _mm256_store_ps(tmp, acc00);
    440           for (size_t dj = 0; dj < 8; dj++) acc[li0 * tj + lj + dj] += tmp[dj];
    441           _mm256_store_ps(tmp, acc10);
    442           for (size_t dj = 0; dj < 8; dj++) acc[li1 * tj + lj + dj] += tmp[dj];
    443           _mm256_store_ps(tmp, acc20);
    444           for (size_t dj = 0; dj < 8; dj++) acc[li2 * tj + lj + dj] += tmp[dj];
    445           _mm256_store_ps(tmp, acc30);
    446           for (size_t dj = 0; dj < 8; dj++) acc[li3 * tj + lj + dj] += tmp[dj];
    447         }
    448 
    449         for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
    450           size_t lj   = j - jj;
    451           double sum0 = 0.0;
    452           double sum1 = 0.0;
    453           double sum2 = 0.0;
    454           double sum3 = 0.0;
    455           for (size_t k = 0; k < n; k++) {
    456             sum0 += (double)A[rn0 + k] * (double)B[k * p + j];
    457             sum1 += (double)A[rn1 + k] * (double)B[k * p + j];
    458             sum2 += (double)A[rn2 + k] * (double)B[k * p + j];
    459             sum3 += (double)A[rn3 + k] * (double)B[k * p + j];
    460           }
    461           acc[li0 * tj + lj] += (float)sum0;
    462           acc[li1 * tj + lj] += (float)sum1;
    463           acc[li2 * tj + lj] += (float)sum2;
    464           acc[li3 * tj + lj] += (float)sum3;
    465         }
    466       }
    467 
    468       for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) {
    469         size_t li = i - ii;
    470         size_t rn = i * n;
    471         for (size_t j = jj; j + 8 <= j_end; j += 8) {
    472           size_t lj    = j - jj;
    473           size_t j8idx = j / 8;
    474           __m256 acc0  = _mm256_setzero_ps();
    475           __m256 acc1  = _mm256_setzero_ps();
    476           __m256 acc2  = _mm256_setzero_ps();
    477           __m256 acc3  = _mm256_setzero_ps();
    478 
    479           for (size_t kk = 0; kk < n; kk += kb) {
    480             size_t k_end  = (kk + kb < n) ? kk + kb : n;
    481             size_t k_end8 = kk + (k_end - kk) / 8 * 8;
    482             size_t k      = kk;
    483             for (; k + 4 <= k_end8; k += 4) {
    484               size_t k8_0 = k / 8, dk_0 = k % 8;
    485               size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8;
    486               size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8;
    487               size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8;
    488               __m256 a0 = _mm256_set1_ps(A[rn + k]);
    489               __m256 a1 = _mm256_set1_ps(A[rn + k + 1]);
    490               __m256 a2 = _mm256_set1_ps(A[rn + k + 2]);
    491               __m256 a3 = _mm256_set1_ps(A[rn + k + 3]);
    492               __m256 b0 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]);
    493               __m256 b1 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]);
    494               __m256 b2 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]);
    495               __m256 b3 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]);
    496               acc0      = _mm256_fmadd_ps(a0, b0, acc0);
    497               acc1      = _mm256_fmadd_ps(a1, b1, acc1);
    498               acc2      = _mm256_fmadd_ps(a2, b2, acc2);
    499               acc3      = _mm256_fmadd_ps(a3, b3, acc3);
    500             }
    501             for (; k < k_end; k++) {
    502               size_t k8 = k / 8, dk = k % 8;
    503               __m256 a_bcast = _mm256_set1_ps(A[rn + k]);
    504               __m256 b_val   = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
    505               acc0           = _mm256_fmadd_ps(a_bcast, b_val, acc0);
    506             }
    507           }
    508 
    509           acc0 = _mm256_add_ps(acc0, acc1);
    510           acc2 = _mm256_add_ps(acc2, acc3);
    511           acc0 = _mm256_add_ps(acc0, acc2);
    512 
    513           float tmp[8] __attribute__((aligned(32)));
    514           _mm256_store_ps(tmp, acc0);
    515           for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj];
    516         }
    517         for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
    518           size_t lj = j - jj;
    519           for (size_t k = 0; k < n; k++) {
    520             acc[li * tj + lj] += (double)A[rn + k] * (double)B[k * p + j];
    521           }
    522         }
    523       }
    524 
    525       for (size_t i = ii; i < i_end; i++) {
    526         size_t li = i - ii;
    527         for (size_t j = jj; j < j_end; j++) {
    528           size_t lj    = j - jj;
    529           C[i * p + j] = acc[li * tj + lj] * inv_scale;
    530         }
    531       }
    532     }
    533   }
    534 
    535   for (size_t i = 0; i < m; i++) {
    536     for (size_t j = p8 * 8; j < p; j++) {
    537       double sum = 0.0;
    538       for (size_t k = 0; k < n; k++) {
    539         sum += (double)A[i * n + k] * (double)B[k * p + j];
    540       }
    541       C[i * p + j] = (float)(sum * inv_scale);
    542     }
    543   }
    544 
    545   free(B_packed);
    546   return 0;
    547 }
    548 #endif
    549 
    550 #ifdef __AVX512F__
    551 static void pack_b_f32_512(size_t n, size_t p, const float *B, float *B_packed) {
    552   size_t n16 = n / 16;
    553   size_t p16 = p / 16;
    554   for (size_t j16 = 0; j16 < p16; j16++) {
    555     for (size_t k16 = 0; k16 < n16; k16++) {
    556       float *dst = &B_packed[(j16 * n16 + k16) * 256];
    557       for (size_t dk = 0; dk < 16; dk++) {
    558         size_t k = k16 * 16 + dk;
    559         for (size_t dj = 0; dj < 16; dj++) {
    560           size_t j          = j16 * 16 + dj;
    561           dst[dk * 16 + dj] = B[k * p + j];
    562         }
    563       }
    564     }
    565   }
    566 }
    567 
    568 int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
    569   const size_t ib = 64;
    570   const size_t jb = 64;
    571   const size_t kb = 16;
    572 
    573   size_t n16 = n / 16;
    574   size_t p16 = p / 16;
    575 
    576   float *B_packed;
    577   if (posix_memalign((void **)&B_packed, 64, p16 * n16 * 256 * sizeof(float)) != 0) return -1;
    578   pack_b_f32_512(n, p, B, B_packed);
    579 
    580   float inv_scale = (scale > 1.0) ? 1.0f / (float)scale : 1.0f;
    581 
    582 #pragma omp parallel for schedule(static)
    583   for (size_t ii = 0; ii < m; ii += ib) {
    584     size_t i_end = (ii + ib < m) ? ii + ib : m;
    585     for (size_t jj = 0; jj < p16 * 16; jj += jb) {
    586       size_t j_end = (jj + jb < p16 * 16) ? jj + jb : p16 * 16;
    587       size_t ti    = i_end - ii;
    588       size_t tj    = j_end - jj;
    589       float  acc[64 * 64];
    590       memset(acc, 0, ti * tj * sizeof(float));
    591 
    592       for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) {
    593         size_t li0 = i4 - ii;
    594         size_t li1 = li0 + 1;
    595         size_t li2 = li0 + 2;
    596         size_t li3 = li0 + 3;
    597         size_t rn0 = i4 * n;
    598         size_t rn1 = (i4 + 1) * n;
    599         size_t rn2 = (i4 + 2) * n;
    600         size_t rn3 = (i4 + 3) * n;
    601 
    602         for (size_t j = jj; j + 16 <= j_end; j += 16) {
    603           size_t lj     = j - jj;
    604           size_t j16idx = j / 16;
    605           __m512 acc00  = _mm512_setzero_ps();
    606           __m512 acc01  = _mm512_setzero_ps();
    607           __m512 acc02  = _mm512_setzero_ps();
    608           __m512 acc03  = _mm512_setzero_ps();
    609           __m512 acc10  = _mm512_setzero_ps();
    610           __m512 acc11  = _mm512_setzero_ps();
    611           __m512 acc12  = _mm512_setzero_ps();
    612           __m512 acc13  = _mm512_setzero_ps();
    613           __m512 acc20  = _mm512_setzero_ps();
    614           __m512 acc21  = _mm512_setzero_ps();
    615           __m512 acc22  = _mm512_setzero_ps();
    616           __m512 acc23  = _mm512_setzero_ps();
    617           __m512 acc30  = _mm512_setzero_ps();
    618           __m512 acc31  = _mm512_setzero_ps();
    619           __m512 acc32  = _mm512_setzero_ps();
    620           __m512 acc33  = _mm512_setzero_ps();
    621 
    622           for (size_t kk = 0; kk < n; kk += kb) {
    623             size_t k_end   = (kk + kb < n) ? kk + kb : n;
    624             size_t k_end16 = kk + (k_end - kk) / 16 * 16;
    625             size_t k       = kk;
    626             for (; k + 4 <= k_end16; k += 4) {
    627               size_t k16_0 = k / 16, dk_0 = k % 16;
    628               size_t k16_1 = (k + 1) / 16, dk_1 = (k + 1) % 16;
    629               size_t k16_2 = (k + 2) / 16, dk_2 = (k + 2) % 16;
    630               size_t k16_3 = (k + 3) / 16, dk_3 = (k + 3) % 16;
    631               __m512 a0 = _mm512_set1_ps(A[rn0 + k]);
    632               __m512 a1 = _mm512_set1_ps(A[rn0 + k + 1]);
    633               __m512 a2 = _mm512_set1_ps(A[rn0 + k + 2]);
    634               __m512 a3 = _mm512_set1_ps(A[rn0 + k + 3]);
    635               __m512 b0 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_0) * 256 + dk_0 * 16]);
    636               __m512 b1 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_1) * 256 + dk_1 * 16]);
    637               __m512 b2 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_2) * 256 + dk_2 * 16]);
    638               __m512 b3 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_3) * 256 + dk_3 * 16]);
    639               acc00     = _mm512_fmadd_ps(a0, b0, acc00);
    640               acc01     = _mm512_fmadd_ps(a1, b1, acc01);
    641               acc02     = _mm512_fmadd_ps(a2, b2, acc02);
    642               acc03     = _mm512_fmadd_ps(a3, b3, acc03);
    643               a0        = _mm512_set1_ps(A[rn1 + k]);
    644               a1        = _mm512_set1_ps(A[rn1 + k + 1]);
    645               a2        = _mm512_set1_ps(A[rn1 + k + 2]);
    646               a3        = _mm512_set1_ps(A[rn1 + k + 3]);
    647               acc10     = _mm512_fmadd_ps(a0, b0, acc10);
    648               acc11     = _mm512_fmadd_ps(a1, b1, acc11);
    649               acc12     = _mm512_fmadd_ps(a2, b2, acc12);
    650               acc13     = _mm512_fmadd_ps(a3, b3, acc13);
    651               a0        = _mm512_set1_ps(A[rn2 + k]);
    652               a1        = _mm512_set1_ps(A[rn2 + k + 1]);
    653               a2        = _mm512_set1_ps(A[rn2 + k + 2]);
    654               a3        = _mm512_set1_ps(A[rn2 + k + 3]);
    655               acc20     = _mm512_fmadd_ps(a0, b0, acc20);
    656               acc21     = _mm512_fmadd_ps(a1, b1, acc21);
    657               acc22     = _mm512_fmadd_ps(a2, b2, acc22);
    658               acc23     = _mm512_fmadd_ps(a3, b3, acc23);
    659               a0        = _mm512_set1_ps(A[rn3 + k]);
    660               a1        = _mm512_set1_ps(A[rn3 + k + 1]);
    661               a2        = _mm512_set1_ps(A[rn3 + k + 2]);
    662               a3        = _mm512_set1_ps(A[rn3 + k + 3]);
    663               acc30     = _mm512_fmadd_ps(a0, b0, acc30);
    664               acc31     = _mm512_fmadd_ps(a1, b1, acc31);
    665               acc32     = _mm512_fmadd_ps(a2, b2, acc32);
    666               acc33     = _mm512_fmadd_ps(a3, b3, acc33);
    667             }
    668             for (; k < k_end; k++) {
    669               size_t k16 = k / 16, dk = k % 16;
    670               __m512 a_bcast = _mm512_set1_ps(A[rn0 + k]);
    671               __m512 b_val   = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]);
    672               acc00          = _mm512_fmadd_ps(a_bcast, b_val, acc00);
    673               a_bcast        = _mm512_set1_ps(A[rn1 + k]);
    674               acc10          = _mm512_fmadd_ps(a_bcast, b_val, acc10);
    675               a_bcast        = _mm512_set1_ps(A[rn2 + k]);
    676               acc20          = _mm512_fmadd_ps(a_bcast, b_val, acc20);
    677               a_bcast        = _mm512_set1_ps(A[rn3 + k]);
    678               acc30          = _mm512_fmadd_ps(a_bcast, b_val, acc30);
    679             }
    680           }
    681 
    682           acc00 = _mm512_add_ps(acc00, acc01);
    683           acc02 = _mm512_add_ps(acc02, acc03);
    684           acc00 = _mm512_add_ps(acc00, acc02);
    685           acc10 = _mm512_add_ps(acc10, acc11);
    686           acc12 = _mm512_add_ps(acc12, acc13);
    687           acc10 = _mm512_add_ps(acc10, acc12);
    688           acc20 = _mm512_add_ps(acc20, acc21);
    689           acc22 = _mm512_add_ps(acc22, acc23);
    690           acc20 = _mm512_add_ps(acc20, acc22);
    691           acc30 = _mm512_add_ps(acc30, acc31);
    692           acc32 = _mm512_add_ps(acc32, acc33);
    693           acc30 = _mm512_add_ps(acc30, acc32);
    694 
    695           float tmp[16] __attribute__((aligned(64)));
    696           _mm512_store_ps(tmp, acc00);
    697           for (size_t dj = 0; dj < 16; dj++) acc[li0 * tj + lj + dj] += tmp[dj];
    698           _mm512_store_ps(tmp, acc10);
    699           for (size_t dj = 0; dj < 16; dj++) acc[li1 * tj + lj + dj] += tmp[dj];
    700           _mm512_store_ps(tmp, acc20);
    701           for (size_t dj = 0; dj < 16; dj++) acc[li2 * tj + lj + dj] += tmp[dj];
    702           _mm512_store_ps(tmp, acc30);
    703           for (size_t dj = 0; dj < 16; dj++) acc[li3 * tj + lj + dj] += tmp[dj];
    704         }
    705 
    706         for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) {
    707           size_t lj   = j - jj;
    708           double sum0 = 0.0;
    709           double sum1 = 0.0;
    710           double sum2 = 0.0;
    711           double sum3 = 0.0;
    712           for (size_t k = 0; k < n; k++) {
    713             sum0 += (double)A[rn0 + k] * (double)B[k * p + j];
    714             sum1 += (double)A[rn1 + k] * (double)B[k * p + j];
    715             sum2 += (double)A[rn2 + k] * (double)B[k * p + j];
    716             sum3 += (double)A[rn3 + k] * (double)B[k * p + j];
    717           }
    718           acc[li0 * tj + lj] += (float)sum0;
    719           acc[li1 * tj + lj] += (float)sum1;
    720           acc[li2 * tj + lj] += (float)sum2;
    721           acc[li3 * tj + lj] += (float)sum3;
    722         }
    723       }
    724 
    725       for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) {
    726         size_t li = i - ii;
    727         size_t rn = i * n;
    728         for (size_t j = jj; j + 16 <= j_end; j += 16) {
    729           size_t lj     = j - jj;
    730           size_t j16idx = j / 16;
    731           __m512 acc0   = _mm512_setzero_ps();
    732           __m512 acc1   = _mm512_setzero_ps();
    733           __m512 acc2   = _mm512_setzero_ps();
    734           __m512 acc3   = _mm512_setzero_ps();
    735 
    736           for (size_t kk = 0; kk < n; kk += kb) {
    737             size_t k_end   = (kk + kb < n) ? kk + kb : n;
    738             size_t k_end16 = kk + (k_end - kk) / 16 * 16;
    739             size_t k       = kk;
    740             for (; k + 4 <= k_end16; k += 4) {
    741               size_t k16_0 = k / 16, dk_0 = k % 16;
    742               size_t k16_1 = (k + 1) / 16, dk_1 = (k + 1) % 16;
    743               size_t k16_2 = (k + 2) / 16, dk_2 = (k + 2) % 16;
    744               size_t k16_3 = (k + 3) / 16, dk_3 = (k + 3) % 16;
    745               __m512 a0 = _mm512_set1_ps(A[rn + k]);
    746               __m512 a1 = _mm512_set1_ps(A[rn + k + 1]);
    747               __m512 a2 = _mm512_set1_ps(A[rn + k + 2]);
    748               __m512 a3 = _mm512_set1_ps(A[rn + k + 3]);
    749               __m512 b0 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_0) * 256 + dk_0 * 16]);
    750               __m512 b1 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_1) * 256 + dk_1 * 16]);
    751               __m512 b2 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_2) * 256 + dk_2 * 16]);
    752               __m512 b3 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_3) * 256 + dk_3 * 16]);
    753               acc0      = _mm512_fmadd_ps(a0, b0, acc0);
    754               acc1      = _mm512_fmadd_ps(a1, b1, acc1);
    755               acc2      = _mm512_fmadd_ps(a2, b2, acc2);
    756               acc3      = _mm512_fmadd_ps(a3, b3, acc3);
    757             }
    758             for (; k < k_end; k++) {
    759               size_t k16 = k / 16, dk = k % 16;
    760               __m512 a_bcast = _mm512_set1_ps(A[rn + k]);
    761               __m512 b_val   = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]);
    762               acc0           = _mm512_fmadd_ps(a_bcast, b_val, acc0);
    763             }
    764           }
    765 
    766           acc0 = _mm512_add_ps(acc0, acc1);
    767           acc2 = _mm512_add_ps(acc2, acc3);
    768           acc0 = _mm512_add_ps(acc0, acc2);
    769 
    770           float tmp[16] __attribute__((aligned(64)));
    771           _mm512_store_ps(tmp, acc0);
    772           for (size_t dj = 0; dj < 16; dj++) acc[li * tj + lj + dj] += tmp[dj];
    773         }
    774         for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) {
    775           size_t lj = j - jj;
    776           for (size_t k = 0; k < n; k++) {
    777             acc[li * tj + lj] += (double)A[rn + k] * (double)B[k * p + j];
    778           }
    779         }
    780       }
    781 
    782       for (size_t i = ii; i < i_end; i++) {
    783         size_t li = i - ii;
    784         for (size_t j = jj; j < j_end; j++) {
    785           size_t lj    = j - jj;
    786           C[i * p + j] = acc[li * tj + lj] * inv_scale;
    787         }
    788       }
    789     }
    790   }
    791 
    792   for (size_t i = 0; i < m; i++) {
    793     for (size_t j = p16 * 16; j < p; j++) {
    794       double sum = 0.0;
    795       for (size_t k = 0; k < n; k++) {
    796         sum += (double)A[i * n + k] * (double)B[k * p + j];
    797       }
    798       C[i * p + j] = (float)(sum * inv_scale);
    799     }
    800   }
    801 
    802   free(B_packed);
    803   return 0;
    804 }
    805 #endif
    806 
    807 static int _matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
    808   static int initialized = 0;
    809   if (!initialized) {
    810     matmul_feature_t feat = matmul_get_feature();
    811 #ifdef __AVX512F__
    812     if (feat & MATMUL_FLAG_AVX512)
    813       matmul_f32_f32_f32 = matmul_avx512_f32_f32_f32;
    814     else
    815 #endif
    816 #ifdef __AVX2__
    817         if (feat & MATMUL_FLAG_AVX2)
    818       matmul_f32_f32_f32 = matmul_avx2_f32_f32_f32;
    819     else
    820 #endif
    821       matmul_f32_f32_f32 = matmul_scalar_f32_f32_f32;
    822     initialized = 1;
    823   }
    824   return matmul_f32_f32_f32(m, n, p, A, B, C, scale);
    825 }
    826 
    827 int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double) = _matmul_f32_f32_f32;
    828 
    829 /* ========================================================================== */
    830 /* f64_f64_f64 implementations                                                */
    831 /* ========================================================================== */
    832 
    833 int matmul_scalar_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
    834   const size_t ib = 64;
    835   const size_t jb = 64;
    836   const size_t kb = 8;
    837 
    838 #pragma omp parallel for schedule(static)
    839   for (size_t ii = 0; ii < m; ii += ib) {
    840     size_t i_end = (ii + ib < m) ? ii + ib : m;
    841     for (size_t jj = 0; jj < p; jj += jb) {
    842       size_t j_end = (jj + jb < p) ? jj + jb : p;
    843       size_t ti    = i_end - ii;
    844       size_t tj    = j_end - jj;
    845       double acc[64 * 64];
    846       memset(acc, 0, ti * tj * sizeof(double));
    847 
    848       for (size_t kk = 0; kk < n; kk += kb) {
    849         size_t k_end = (kk + kb < n) ? kk + kb : n;
    850         for (size_t i = ii; i < i_end; i++) {
    851           size_t li = i - ii;
    852           for (size_t j = jj; j < j_end; j++) {
    853             size_t lj  = j - jj;
    854             double sum = 0.0;
    855             for (size_t k = kk; k < k_end; k++) {
    856               sum += A[i * n + k] * B[k * p + j];
    857             }
    858             acc[li * tj + lj] += sum;
    859           }
    860         }
    861       }
    862 
    863       for (size_t i = ii; i < i_end; i++) {
    864         size_t li = i - ii;
    865         for (size_t j = jj; j < j_end; j++) {
    866           size_t lj = j - jj;
    867           double v  = acc[li * tj + lj];
    868           if (scale > 1.0) v /= scale;
    869           C[i * p + j] = v;
    870         }
    871       }
    872     }
    873   }
    874   return 0;
    875 }
    876 
    877 #ifdef __AVX2__
    878 static void pack_b_f64(size_t n, size_t p, const double *B, double *B_packed) {
    879   size_t n4 = n / 4;
    880   size_t p4 = p / 4;
    881   for (size_t j4 = 0; j4 < p4; j4++) {
    882     for (size_t k4 = 0; k4 < n4; k4++) {
    883       double *dst = &B_packed[(j4 * n4 + k4) * 16];
    884       for (size_t dk = 0; dk < 4; dk++) {
    885         size_t k = k4 * 4 + dk;
    886         for (size_t dj = 0; dj < 4; dj++) {
    887           size_t j         = j4 * 4 + dj;
    888           dst[dk * 4 + dj] = B[k * p + j];
    889         }
    890       }
    891     }
    892   }
    893 }
    894 
    895 int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
    896   const size_t ib = 64;
    897   const size_t jb = 64;
    898   const size_t kb = 8;
    899 
    900   size_t n4 = n / 4;
    901   size_t p4 = p / 4;
    902 
    903   double *B_packed;
    904   if (posix_memalign((void **)&B_packed, 64, p4 * n4 * 16 * sizeof(double)) != 0) return -1;
    905   pack_b_f64(n, p, B, B_packed);
    906 
    907   double inv_scale = (scale > 1.0) ? 1.0 / scale : 1.0;
    908 
    909 #pragma omp parallel for schedule(static)
    910   for (size_t ii = 0; ii < m; ii += ib) {
    911     size_t i_end = (ii + ib < m) ? ii + ib : m;
    912     for (size_t jj = 0; jj < p4 * 4; jj += jb) {
    913       size_t j_end = (jj + jb < p4 * 4) ? jj + jb : p4 * 4;
    914       size_t ti    = i_end - ii;
    915       size_t tj    = j_end - jj;
    916       double acc[64 * 64];
    917       memset(acc, 0, ti * tj * sizeof(double));
    918 
    919       for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) {
    920         size_t li0 = i4 - ii;
    921         size_t li1 = li0 + 1;
    922         size_t li2 = li0 + 2;
    923         size_t li3 = li0 + 3;
    924         size_t rn0 = i4 * n;
    925         size_t rn1 = (i4 + 1) * n;
    926         size_t rn2 = (i4 + 2) * n;
    927         size_t rn3 = (i4 + 3) * n;
    928 
    929         for (size_t j = jj; j + 4 <= j_end; j += 4) {
    930           size_t  lj    = j - jj;
    931           size_t  j4idx = j / 4;
    932           __m256d acc00 = _mm256_setzero_pd();
    933           __m256d acc01 = _mm256_setzero_pd();
    934           __m256d acc02 = _mm256_setzero_pd();
    935           __m256d acc03 = _mm256_setzero_pd();
    936           __m256d acc10 = _mm256_setzero_pd();
    937           __m256d acc11 = _mm256_setzero_pd();
    938           __m256d acc12 = _mm256_setzero_pd();
    939           __m256d acc13 = _mm256_setzero_pd();
    940           __m256d acc20 = _mm256_setzero_pd();
    941           __m256d acc21 = _mm256_setzero_pd();
    942           __m256d acc22 = _mm256_setzero_pd();
    943           __m256d acc23 = _mm256_setzero_pd();
    944           __m256d acc30 = _mm256_setzero_pd();
    945           __m256d acc31 = _mm256_setzero_pd();
    946           __m256d acc32 = _mm256_setzero_pd();
    947           __m256d acc33 = _mm256_setzero_pd();
    948 
    949           for (size_t kk = 0; kk < n; kk += kb) {
    950             size_t k_end  = (kk + kb < n) ? kk + kb : n;
    951             size_t k_end4 = kk + (k_end - kk) / 4 * 4;
    952             size_t k      = kk;
    953             for (; k + 4 <= k_end4; k += 4) {
    954               size_t  k4_0 = k / 4, dk_0 = k % 4;
    955               size_t  k4_1 = (k + 1) / 4, dk_1 = (k + 1) % 4;
    956               size_t  k4_2 = (k + 2) / 4, dk_2 = (k + 2) % 4;
    957               size_t  k4_3 = (k + 3) / 4, dk_3 = (k + 3) % 4;
    958               __m256d a0 = _mm256_set1_pd(A[rn0 + k]);
    959               __m256d a1 = _mm256_set1_pd(A[rn0 + k + 1]);
    960               __m256d a2 = _mm256_set1_pd(A[rn0 + k + 2]);
    961               __m256d a3 = _mm256_set1_pd(A[rn0 + k + 3]);
    962               __m256d b0 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_0) * 16 + dk_0 * 4]);
    963               __m256d b1 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_1) * 16 + dk_1 * 4]);
    964               __m256d b2 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_2) * 16 + dk_2 * 4]);
    965               __m256d b3 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_3) * 16 + dk_3 * 4]);
    966               acc00      = _mm256_fmadd_pd(a0, b0, acc00);
    967               acc01      = _mm256_fmadd_pd(a1, b1, acc01);
    968               acc02      = _mm256_fmadd_pd(a2, b2, acc02);
    969               acc03      = _mm256_fmadd_pd(a3, b3, acc03);
    970               a0         = _mm256_set1_pd(A[rn1 + k]);
    971               a1         = _mm256_set1_pd(A[rn1 + k + 1]);
    972               a2         = _mm256_set1_pd(A[rn1 + k + 2]);
    973               a3         = _mm256_set1_pd(A[rn1 + k + 3]);
    974               acc10      = _mm256_fmadd_pd(a0, b0, acc10);
    975               acc11      = _mm256_fmadd_pd(a1, b1, acc11);
    976               acc12      = _mm256_fmadd_pd(a2, b2, acc12);
    977               acc13      = _mm256_fmadd_pd(a3, b3, acc13);
    978               a0         = _mm256_set1_pd(A[rn2 + k]);
    979               a1         = _mm256_set1_pd(A[rn2 + k + 1]);
    980               a2         = _mm256_set1_pd(A[rn2 + k + 2]);
    981               a3         = _mm256_set1_pd(A[rn2 + k + 3]);
    982               acc20      = _mm256_fmadd_pd(a0, b0, acc20);
    983               acc21      = _mm256_fmadd_pd(a1, b1, acc21);
    984               acc22      = _mm256_fmadd_pd(a2, b2, acc22);
    985               acc23      = _mm256_fmadd_pd(a3, b3, acc23);
    986               a0         = _mm256_set1_pd(A[rn3 + k]);
    987               a1         = _mm256_set1_pd(A[rn3 + k + 1]);
    988               a2         = _mm256_set1_pd(A[rn3 + k + 2]);
    989               a3         = _mm256_set1_pd(A[rn3 + k + 3]);
    990               acc30      = _mm256_fmadd_pd(a0, b0, acc30);
    991               acc31      = _mm256_fmadd_pd(a1, b1, acc31);
    992               acc32      = _mm256_fmadd_pd(a2, b2, acc32);
    993               acc33      = _mm256_fmadd_pd(a3, b3, acc33);
    994             }
    995             for (; k < k_end; k++) {
    996               size_t  k4 = k / 4, dk = k % 4;
    997               __m256d a_bcast = _mm256_set1_pd(A[rn0 + k]);
    998               __m256d b_val   = _mm256_load_pd(&B_packed[(j4idx * n4 + k4) * 16 + dk * 4]);
    999               acc00           = _mm256_fmadd_pd(a_bcast, b_val, acc00);
   1000               a_bcast         = _mm256_set1_pd(A[rn1 + k]);
   1001               acc10           = _mm256_fmadd_pd(a_bcast, b_val, acc10);
   1002               a_bcast         = _mm256_set1_pd(A[rn2 + k]);
   1003               acc20           = _mm256_fmadd_pd(a_bcast, b_val, acc20);
   1004               a_bcast         = _mm256_set1_pd(A[rn3 + k]);
   1005               acc30           = _mm256_fmadd_pd(a_bcast, b_val, acc30);
   1006             }
   1007           }
   1008 
   1009           acc00 = _mm256_add_pd(acc00, acc01);
   1010           acc02 = _mm256_add_pd(acc02, acc03);
   1011           acc00 = _mm256_add_pd(acc00, acc02);
   1012           acc10 = _mm256_add_pd(acc10, acc11);
   1013           acc12 = _mm256_add_pd(acc12, acc13);
   1014           acc10 = _mm256_add_pd(acc10, acc12);
   1015           acc20 = _mm256_add_pd(acc20, acc21);
   1016           acc22 = _mm256_add_pd(acc22, acc23);
   1017           acc20 = _mm256_add_pd(acc20, acc22);
   1018           acc30 = _mm256_add_pd(acc30, acc31);
   1019           acc32 = _mm256_add_pd(acc32, acc33);
   1020           acc30 = _mm256_add_pd(acc30, acc32);
   1021 
   1022           double tmp[4] __attribute__((aligned(32)));
   1023           _mm256_store_pd(tmp, acc00);
   1024           for (size_t dj = 0; dj < 4; dj++) acc[li0 * tj + lj + dj] += tmp[dj];
   1025           _mm256_store_pd(tmp, acc10);
   1026           for (size_t dj = 0; dj < 4; dj++) acc[li1 * tj + lj + dj] += tmp[dj];
   1027           _mm256_store_pd(tmp, acc20);
   1028           for (size_t dj = 0; dj < 4; dj++) acc[li2 * tj + lj + dj] += tmp[dj];
   1029           _mm256_store_pd(tmp, acc30);
   1030           for (size_t dj = 0; dj < 4; dj++) acc[li3 * tj + lj + dj] += tmp[dj];
   1031         }
   1032 
   1033         for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) {
   1034           size_t lj   = j - jj;
   1035           double sum0 = 0.0;
   1036           double sum1 = 0.0;
   1037           double sum2 = 0.0;
   1038           double sum3 = 0.0;
   1039           for (size_t k = 0; k < n; k++) {
   1040             sum0 += A[rn0 + k] * B[k * p + j];
   1041             sum1 += A[rn1 + k] * B[k * p + j];
   1042             sum2 += A[rn2 + k] * B[k * p + j];
   1043             sum3 += A[rn3 + k] * B[k * p + j];
   1044           }
   1045           acc[li0 * tj + lj] += sum0;
   1046           acc[li1 * tj + lj] += sum1;
   1047           acc[li2 * tj + lj] += sum2;
   1048           acc[li3 * tj + lj] += sum3;
   1049         }
   1050       }
   1051 
   1052       for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) {
   1053         size_t li = i - ii;
   1054         size_t rn = i * n;
   1055         for (size_t j = jj; j + 4 <= j_end; j += 4) {
   1056           size_t  lj    = j - jj;
   1057           size_t  j4idx = j / 4;
   1058           __m256d acc0  = _mm256_setzero_pd();
   1059           __m256d acc1  = _mm256_setzero_pd();
   1060           __m256d acc2  = _mm256_setzero_pd();
   1061           __m256d acc3  = _mm256_setzero_pd();
   1062 
   1063           for (size_t kk = 0; kk < n; kk += kb) {
   1064             size_t k_end  = (kk + kb < n) ? kk + kb : n;
   1065             size_t k_end4 = kk + (k_end - kk) / 4 * 4;
   1066             size_t k      = kk;
   1067             for (; k + 4 <= k_end4; k += 4) {
   1068               size_t  k4_0 = k / 4, dk_0 = k % 4;
   1069               size_t  k4_1 = (k + 1) / 4, dk_1 = (k + 1) % 4;
   1070               size_t  k4_2 = (k + 2) / 4, dk_2 = (k + 2) % 4;
   1071               size_t  k4_3 = (k + 3) / 4, dk_3 = (k + 3) % 4;
   1072               __m256d a0 = _mm256_set1_pd(A[rn + k]);
   1073               __m256d a1 = _mm256_set1_pd(A[rn + k + 1]);
   1074               __m256d a2 = _mm256_set1_pd(A[rn + k + 2]);
   1075               __m256d a3 = _mm256_set1_pd(A[rn + k + 3]);
   1076               __m256d b0 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_0) * 16 + dk_0 * 4]);
   1077               __m256d b1 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_1) * 16 + dk_1 * 4]);
   1078               __m256d b2 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_2) * 16 + dk_2 * 4]);
   1079               __m256d b3 = _mm256_load_pd(&B_packed[(j4idx * n4 + k4_3) * 16 + dk_3 * 4]);
   1080               acc0       = _mm256_fmadd_pd(a0, b0, acc0);
   1081               acc1       = _mm256_fmadd_pd(a1, b1, acc1);
   1082               acc2       = _mm256_fmadd_pd(a2, b2, acc2);
   1083               acc3       = _mm256_fmadd_pd(a3, b3, acc3);
   1084             }
   1085             for (; k < k_end; k++) {
   1086               size_t  k4 = k / 4, dk = k % 4;
   1087               __m256d a_bcast = _mm256_set1_pd(A[rn + k]);
   1088               __m256d b_val   = _mm256_load_pd(&B_packed[(j4idx * n4 + k4) * 16 + dk * 4]);
   1089               acc0            = _mm256_fmadd_pd(a_bcast, b_val, acc0);
   1090             }
   1091           }
   1092 
   1093           acc0 = _mm256_add_pd(acc0, acc1);
   1094           acc2 = _mm256_add_pd(acc2, acc3);
   1095           acc0 = _mm256_add_pd(acc0, acc2);
   1096 
   1097           double tmp[4] __attribute__((aligned(32)));
   1098           _mm256_store_pd(tmp, acc0);
   1099           for (size_t dj = 0; dj < 4; dj++) acc[li * tj + lj + dj] += tmp[dj];
   1100         }
   1101         for (size_t j = jj + (tj / 4) * 4; j < j_end; j++) {
   1102           size_t lj = j - jj;
   1103           for (size_t k = 0; k < n; k++) {
   1104             acc[li * tj + lj] += A[rn + k] * B[k * p + j];
   1105           }
   1106         }
   1107       }
   1108 
   1109       for (size_t i = ii; i < i_end; i++) {
   1110         size_t li = i - ii;
   1111         for (size_t j = jj; j < j_end; j++) {
   1112           size_t lj    = j - jj;
   1113           C[i * p + j] = acc[li * tj + lj] * inv_scale;
   1114         }
   1115       }
   1116     }
   1117   }
   1118 
   1119   for (size_t i = 0; i < m; i++) {
   1120     for (size_t j = p4 * 4; j < p; j++) {
   1121       double sum = 0.0;
   1122       for (size_t k = 0; k < n; k++) {
   1123         sum += A[i * n + k] * B[k * p + j];
   1124       }
   1125       C[i * p + j] = sum * inv_scale;
   1126     }
   1127   }
   1128 
   1129   free(B_packed);
   1130   return 0;
   1131 }
   1132 #endif
   1133 
   1134 #ifdef __AVX512F__
   1135 static void pack_b_f64_512(size_t n, size_t p, const double *B, double *B_packed) {
   1136   size_t n8 = n / 8;
   1137   size_t p8 = p / 8;
   1138   for (size_t j8 = 0; j8 < p8; j8++) {
   1139     for (size_t k8 = 0; k8 < n8; k8++) {
   1140       double *dst = &B_packed[(j8 * n8 + k8) * 64];
   1141       for (size_t dk = 0; dk < 8; dk++) {
   1142         size_t k = k8 * 8 + dk;
   1143         for (size_t dj = 0; dj < 8; dj++) {
   1144           size_t j         = j8 * 8 + dj;
   1145           dst[dk * 8 + dj] = B[k * p + j];
   1146         }
   1147       }
   1148     }
   1149   }
   1150 }
   1151 
   1152 int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
   1153   const size_t ib = 64;
   1154   const size_t jb = 64;
   1155   const size_t kb = 8;
   1156 
   1157   size_t n8 = n / 8;
   1158   size_t p8 = p / 8;
   1159 
   1160   double *B_packed;
   1161   if (posix_memalign((void **)&B_packed, 64, p8 * n8 * 64 * sizeof(double)) != 0) return -1;
   1162   pack_b_f64_512(n, p, B, B_packed);
   1163 
   1164   double inv_scale = (scale > 1.0) ? 1.0 / scale : 1.0;
   1165 
   1166 #pragma omp parallel for schedule(static)
   1167   for (size_t ii = 0; ii < m; ii += ib) {
   1168     size_t i_end = (ii + ib < m) ? ii + ib : m;
   1169     for (size_t jj = 0; jj < p8 * 8; jj += jb) {
   1170       size_t j_end = (jj + jb < p8 * 8) ? jj + jb : p8 * 8;
   1171       size_t ti    = i_end - ii;
   1172       size_t tj    = j_end - jj;
   1173       double acc[64 * 64];
   1174       memset(acc, 0, ti * tj * sizeof(double));
   1175 
   1176       for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) {
   1177         size_t li0 = i4 - ii;
   1178         size_t li1 = li0 + 1;
   1179         size_t li2 = li0 + 2;
   1180         size_t li3 = li0 + 3;
   1181         size_t rn0 = i4 * n;
   1182         size_t rn1 = (i4 + 1) * n;
   1183         size_t rn2 = (i4 + 2) * n;
   1184         size_t rn3 = (i4 + 3) * n;
   1185 
   1186         for (size_t j = jj; j + 8 <= j_end; j += 8) {
   1187           size_t  lj    = j - jj;
   1188           size_t  j8idx = j / 8;
   1189           __m512d acc00 = _mm512_setzero_pd();
   1190           __m512d acc01 = _mm512_setzero_pd();
   1191           __m512d acc02 = _mm512_setzero_pd();
   1192           __m512d acc03 = _mm512_setzero_pd();
   1193           __m512d acc10 = _mm512_setzero_pd();
   1194           __m512d acc11 = _mm512_setzero_pd();
   1195           __m512d acc12 = _mm512_setzero_pd();
   1196           __m512d acc13 = _mm512_setzero_pd();
   1197           __m512d acc20 = _mm512_setzero_pd();
   1198           __m512d acc21 = _mm512_setzero_pd();
   1199           __m512d acc22 = _mm512_setzero_pd();
   1200           __m512d acc23 = _mm512_setzero_pd();
   1201           __m512d acc30 = _mm512_setzero_pd();
   1202           __m512d acc31 = _mm512_setzero_pd();
   1203           __m512d acc32 = _mm512_setzero_pd();
   1204           __m512d acc33 = _mm512_setzero_pd();
   1205 
   1206           for (size_t kk = 0; kk < n; kk += kb) {
   1207             size_t k_end  = (kk + kb < n) ? kk + kb : n;
   1208             size_t k_end8 = kk + (k_end - kk) / 8 * 8;
   1209             size_t k      = kk;
   1210             for (; k + 4 <= k_end8; k += 4) {
   1211               size_t  k8_0 = k / 8, dk_0 = k % 8;
   1212               size_t  k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8;
   1213               size_t  k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8;
   1214               size_t  k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8;
   1215               __m512d a0 = _mm512_set1_pd(A[rn0 + k]);
   1216               __m512d a1 = _mm512_set1_pd(A[rn0 + k + 1]);
   1217               __m512d a2 = _mm512_set1_pd(A[rn0 + k + 2]);
   1218               __m512d a3 = _mm512_set1_pd(A[rn0 + k + 3]);
   1219               __m512d b0 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]);
   1220               __m512d b1 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]);
   1221               __m512d b2 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]);
   1222               __m512d b3 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]);
   1223               acc00      = _mm512_fmadd_pd(a0, b0, acc00);
   1224               acc01      = _mm512_fmadd_pd(a1, b1, acc01);
   1225               acc02      = _mm512_fmadd_pd(a2, b2, acc02);
   1226               acc03      = _mm512_fmadd_pd(a3, b3, acc03);
   1227               a0         = _mm512_set1_pd(A[rn1 + k]);
   1228               a1         = _mm512_set1_pd(A[rn1 + k + 1]);
   1229               a2         = _mm512_set1_pd(A[rn1 + k + 2]);
   1230               a3         = _mm512_set1_pd(A[rn1 + k + 3]);
   1231               acc10      = _mm512_fmadd_pd(a0, b0, acc10);
   1232               acc11      = _mm512_fmadd_pd(a1, b1, acc11);
   1233               acc12      = _mm512_fmadd_pd(a2, b2, acc12);
   1234               acc13      = _mm512_fmadd_pd(a3, b3, acc13);
   1235               a0         = _mm512_set1_pd(A[rn2 + k]);
   1236               a1         = _mm512_set1_pd(A[rn2 + k + 1]);
   1237               a2         = _mm512_set1_pd(A[rn2 + k + 2]);
   1238               a3         = _mm512_set1_pd(A[rn2 + k + 3]);
   1239               acc20      = _mm512_fmadd_pd(a0, b0, acc20);
   1240               acc21      = _mm512_fmadd_pd(a1, b1, acc21);
   1241               acc22      = _mm512_fmadd_pd(a2, b2, acc22);
   1242               acc23      = _mm512_fmadd_pd(a3, b3, acc23);
   1243               a0         = _mm512_set1_pd(A[rn3 + k]);
   1244               a1         = _mm512_set1_pd(A[rn3 + k + 1]);
   1245               a2         = _mm512_set1_pd(A[rn3 + k + 2]);
   1246               a3         = _mm512_set1_pd(A[rn3 + k + 3]);
   1247               acc30      = _mm512_fmadd_pd(a0, b0, acc30);
   1248               acc31      = _mm512_fmadd_pd(a1, b1, acc31);
   1249               acc32      = _mm512_fmadd_pd(a2, b2, acc32);
   1250               acc33      = _mm512_fmadd_pd(a3, b3, acc33);
   1251             }
   1252             for (; k < k_end; k++) {
   1253               size_t  k8 = k / 8, dk = k % 8;
   1254               __m512d a_bcast = _mm512_set1_pd(A[rn0 + k]);
   1255               __m512d b_val   = _mm512_load_pd(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
   1256               acc00           = _mm512_fmadd_pd(a_bcast, b_val, acc00);
   1257               a_bcast         = _mm512_set1_pd(A[rn1 + k]);
   1258               acc10           = _mm512_fmadd_pd(a_bcast, b_val, acc10);
   1259               a_bcast         = _mm512_set1_pd(A[rn2 + k]);
   1260               acc20           = _mm512_fmadd_pd(a_bcast, b_val, acc20);
   1261               a_bcast         = _mm512_set1_pd(A[rn3 + k]);
   1262               acc30           = _mm512_fmadd_pd(a_bcast, b_val, acc30);
   1263             }
   1264           }
   1265 
   1266           acc00 = _mm512_add_pd(acc00, acc01);
   1267           acc02 = _mm512_add_pd(acc02, acc03);
   1268           acc00 = _mm512_add_pd(acc00, acc02);
   1269           acc10 = _mm512_add_pd(acc10, acc11);
   1270           acc12 = _mm512_add_pd(acc12, acc13);
   1271           acc10 = _mm512_add_pd(acc10, acc12);
   1272           acc20 = _mm512_add_pd(acc20, acc21);
   1273           acc22 = _mm512_add_pd(acc22, acc23);
   1274           acc20 = _mm512_add_pd(acc20, acc22);
   1275           acc30 = _mm512_add_pd(acc30, acc31);
   1276           acc32 = _mm512_add_pd(acc32, acc33);
   1277           acc30 = _mm512_add_pd(acc30, acc32);
   1278 
   1279           double tmp[8] __attribute__((aligned(64)));
   1280           _mm512_store_pd(tmp, acc00);
   1281           for (size_t dj = 0; dj < 8; dj++) acc[li0 * tj + lj + dj] += tmp[dj];
   1282           _mm512_store_pd(tmp, acc10);
   1283           for (size_t dj = 0; dj < 8; dj++) acc[li1 * tj + lj + dj] += tmp[dj];
   1284           _mm512_store_pd(tmp, acc20);
   1285           for (size_t dj = 0; dj < 8; dj++) acc[li2 * tj + lj + dj] += tmp[dj];
   1286           _mm512_store_pd(tmp, acc30);
   1287           for (size_t dj = 0; dj < 8; dj++) acc[li3 * tj + lj + dj] += tmp[dj];
   1288         }
   1289 
   1290         for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
   1291           size_t lj   = j - jj;
   1292           double sum0 = 0.0;
   1293           double sum1 = 0.0;
   1294           double sum2 = 0.0;
   1295           double sum3 = 0.0;
   1296           for (size_t k = 0; k < n; k++) {
   1297             sum0 += A[rn0 + k] * B[k * p + j];
   1298             sum1 += A[rn1 + k] * B[k * p + j];
   1299             sum2 += A[rn2 + k] * B[k * p + j];
   1300             sum3 += A[rn3 + k] * B[k * p + j];
   1301           }
   1302           acc[li0 * tj + lj] += sum0;
   1303           acc[li1 * tj + lj] += sum1;
   1304           acc[li2 * tj + lj] += sum2;
   1305           acc[li3 * tj + lj] += sum3;
   1306         }
   1307       }
   1308 
   1309       for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) {
   1310         size_t li = i - ii;
   1311         size_t rn = i * n;
   1312         for (size_t j = jj; j + 8 <= j_end; j += 8) {
   1313           size_t  lj    = j - jj;
   1314           size_t  j8idx = j / 8;
   1315           __m512d acc0  = _mm512_setzero_pd();
   1316           __m512d acc1  = _mm512_setzero_pd();
   1317           __m512d acc2  = _mm512_setzero_pd();
   1318           __m512d acc3  = _mm512_setzero_pd();
   1319 
   1320           for (size_t kk = 0; kk < n; kk += kb) {
   1321             size_t k_end  = (kk + kb < n) ? kk + kb : n;
   1322             size_t k_end8 = kk + (k_end - kk) / 8 * 8;
   1323             size_t k      = kk;
   1324             for (; k + 4 <= k_end8; k += 4) {
   1325               size_t  k8_0 = k / 8, dk_0 = k % 8;
   1326               size_t  k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8;
   1327               size_t  k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8;
   1328               size_t  k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8;
   1329               __m512d a0 = _mm512_set1_pd(A[rn + k]);
   1330               __m512d a1 = _mm512_set1_pd(A[rn + k + 1]);
   1331               __m512d a2 = _mm512_set1_pd(A[rn + k + 2]);
   1332               __m512d a3 = _mm512_set1_pd(A[rn + k + 3]);
   1333               __m512d b0 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]);
   1334               __m512d b1 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]);
   1335               __m512d b2 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]);
   1336               __m512d b3 = _mm512_load_pd(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]);
   1337               acc0       = _mm512_fmadd_pd(a0, b0, acc0);
   1338               acc1       = _mm512_fmadd_pd(a1, b1, acc1);
   1339               acc2       = _mm512_fmadd_pd(a2, b2, acc2);
   1340               acc3       = _mm512_fmadd_pd(a3, b3, acc3);
   1341             }
   1342             for (; k < k_end; k++) {
   1343               size_t  k8 = k / 8, dk = k % 8;
   1344               __m512d a_bcast = _mm512_set1_pd(A[rn + k]);
   1345               __m512d b_val   = _mm512_load_pd(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]);
   1346               acc0            = _mm512_fmadd_pd(a_bcast, b_val, acc0);
   1347             }
   1348           }
   1349 
   1350           acc0 = _mm512_add_pd(acc0, acc1);
   1351           acc2 = _mm512_add_pd(acc2, acc3);
   1352           acc0 = _mm512_add_pd(acc0, acc2);
   1353 
   1354           double tmp[8] __attribute__((aligned(64)));
   1355           _mm512_store_pd(tmp, acc0);
   1356           for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj];
   1357         }
   1358         for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) {
   1359           size_t lj = j - jj;
   1360           for (size_t k = 0; k < n; k++) {
   1361             acc[li * tj + lj] += A[rn + k] * B[k * p + j];
   1362           }
   1363         }
   1364       }
   1365 
   1366       for (size_t i = ii; i < i_end; i++) {
   1367         size_t li = i - ii;
   1368         for (size_t j = jj; j < j_end; j++) {
   1369           size_t lj    = j - jj;
   1370           C[i * p + j] = acc[li * tj + lj] * inv_scale;
   1371         }
   1372       }
   1373     }
   1374   }
   1375 
   1376   for (size_t i = 0; i < m; i++) {
   1377     for (size_t j = p8 * 8; j < p; j++) {
   1378       double sum = 0.0;
   1379       for (size_t k = 0; k < n; k++) {
   1380         sum += A[i * n + k] * B[k * p + j];
   1381       }
   1382       C[i * p + j] = sum * inv_scale;
   1383     }
   1384   }
   1385 
   1386   free(B_packed);
   1387   return 0;
   1388 }
   1389 #endif
   1390 
   1391 static int _matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C,
   1392                                double scale) {
   1393   static int initialized = 0;
   1394   if (!initialized) {
   1395     matmul_feature_t feat = matmul_get_feature();
   1396 #ifdef __AVX512F__
   1397     if (feat & MATMUL_FLAG_AVX512)
   1398       matmul_f64_f64_f64 = matmul_avx512_f64_f64_f64;
   1399     else
   1400 #endif
   1401 #ifdef __AVX2__
   1402         if (feat & MATMUL_FLAG_AVX2)
   1403       matmul_f64_f64_f64 = matmul_avx2_f64_f64_f64;
   1404     else
   1405 #endif
   1406       matmul_f64_f64_f64 = matmul_scalar_f64_f64_f64;
   1407     initialized = 1;
   1408   }
   1409   return matmul_f64_f64_f64(m, n, p, A, B, C, scale);
   1410 }
   1411 
   1412 int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *,
   1413                           double) = _matmul_f64_f64_f64;