matmul.c (55943B)
1 /* 2 * Copyright (c) 2026 finwo 3 * 4 * Permission is hereby granted, free of charge, to any person obtaining a copy of 5 * this software and associated documentation files (the "Software"), to use, copy, 6 * modify, and distribute the Software, subject to the following conditions: 7 * 8 * 1. Redistributions of source code must retain the above copyright notice, this 9 * list of conditions, and the following disclaimer. 10 * 11 * 2. Redistributions in binary form, or any public offering of the Software 12 * (including hosted or managed services), must reproduce the above copyright 13 * notice, this list of conditions, and the following disclaimer in the 14 * documentation and/or other materials provided. 15 * 16 * 3. Any redistribution or public offering of the Software must clearly attribute 17 * the Software to the original copyright holder, reference this License, and 18 * include a link to the official project repository or website. 19 * 20 * 4. The Software may not be renamed, rebranded, or marketed in a manner that 21 * implies it is an independent or proprietary product. Derivative works must 22 * clearly state that they are based on the Software. 23 * 24 * 5. Modifications to copies of the Software must carry prominent notices stating 25 * that changes were made, the nature of the modifications, and the date of the 26 * modifications. 27 * 28 * Any violation of these conditions terminates the permissions granted herein. 29 * 30 * THIS SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 31 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 32 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT 33 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 34 * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 35 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
36 */ 37 38 #define _GNU_SOURCE 39 #include "matmul.h" 40 41 #include <stdio.h> 42 #include <stdlib.h> 43 #include <string.h> 44 45 #ifdef __AVX2__ 46 #include <immintrin.h> 47 #endif 48 49 #ifdef __AVX512F__ 50 #include <immintrin.h> 51 #endif 52 53 #ifdef __AVXVNNI__ 54 #include <immintrin.h> 55 #endif 56 57 #ifdef __AVX512VNNI__ 58 #include <immintrin.h> 59 #endif 60 61 #ifdef _OPENMP 62 #include <omp.h> 63 #endif 64 65 #define MATMUL_FLAG_SCALAR (1 << 0) 66 #define MATMUL_FLAG_AVX2 (1 << 1) 67 #define MATMUL_FLAG_AVXVNNI (1 << 2) 68 #define MATMUL_FLAG_AVX512 (1 << 3) 69 #define MATMUL_FLAG_AVX512VNNI (1 << 4) 70 71 typedef uint32_t matmul_feature_t; 72 73 static matmul_feature_t g_feature = 0; 74 static int g_initialized = 0; 75 76 static void init_feature(void) { 77 g_feature = MATMUL_FLAG_SCALAR; 78 #ifdef __AVX512VNNI__ 79 if (__builtin_cpu_supports("avx512vnni")) g_feature |= MATMUL_FLAG_AVX512VNNI; 80 #endif 81 #ifdef __AVX512F__ 82 if (__builtin_cpu_supports("avx512f")) g_feature |= MATMUL_FLAG_AVX512; 83 #endif 84 #ifdef __AVXVNNI__ 85 if (__builtin_cpu_supports("avxvnni")) g_feature |= MATMUL_FLAG_AVXVNNI; 86 #endif 87 #ifdef __AVX2__ 88 if (__builtin_cpu_supports("avx2")) g_feature |= MATMUL_FLAG_AVX2; 89 #endif 90 } 91 92 matmul_feature_t matmul_get_feature(void) { 93 if (!g_initialized) { 94 init_feature(); 95 g_initialized = 1; 96 } 97 return g_feature; 98 } 99 100 const char *matmul_get_feature_name(matmul_feature_t feat) { 101 if (feat & MATMUL_FLAG_AVX512VNNI) return "avx512vnni"; 102 if (feat & MATMUL_FLAG_AVX512) return "avx512"; 103 if (feat & MATMUL_FLAG_AVXVNNI) return "avxvnni"; 104 if (feat & MATMUL_FLAG_AVX2) return "avx2"; 105 if (feat & MATMUL_FLAG_SCALAR) return "scalar"; 106 return "unknown"; 107 } 108 109 int matmul_scalar_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { 110 (void)scale; 111 const size_t ib = 64; 112 const size_t jb = 64; 113 const size_t kb = 32; 114 115 #pragma 
omp parallel for schedule(static) 116 for (size_t ii = 0; ii < m; ii += ib) { 117 size_t i_end = (ii + ib < m) ? ii + ib : m; 118 for (size_t jj = 0; jj < p; jj += jb) { 119 size_t j_end = (jj + jb < p) ? jj + jb : p; 120 size_t ti = i_end - ii; 121 size_t tj = j_end - jj; 122 int32_t acc[64 * 64]; 123 memset(acc, 0, ti * tj * sizeof(int32_t)); 124 125 for (size_t kk = 0; kk < n; kk += kb) { 126 size_t k_end = (kk + kb < n) ? kk + kb : n; 127 for (size_t i = ii; i < i_end; i++) { 128 size_t li = i - ii; 129 for (size_t j = jj; j < j_end; j++) { 130 size_t lj = j - jj; 131 int sum = 0; 132 for (size_t k = kk; k < k_end; k++) { 133 sum += (int)A[i * n + k] * (int)B[k * p + j]; 134 } 135 acc[li * tj + lj] += sum; 136 } 137 } 138 } 139 140 for (size_t i = ii; i < i_end; i++) { 141 size_t li = i - ii; 142 for (size_t j = jj; j < j_end; j++) { 143 size_t lj = j - jj; 144 int v = acc[li * tj + lj]; 145 if (scale > 1.0) v = (int)(v / scale); 146 if (v > 255) 147 v = 255; 148 else if (v < 0) 149 v = 0; 150 C[i * p + j] = (uint8_t)v; 151 } 152 } 153 } 154 } 155 return 0; 156 } 157 158 #ifdef __AVX512VNNI__ 159 static void pack_b_i8(size_t n, size_t p, const int8_t *B, int8_t *B_packed) { 160 size_t n4 = n / 4; 161 size_t p16 = p / 16; 162 for (size_t j16 = 0; j16 < p16; j16++) { 163 for (size_t k4 = 0; k4 < n4; k4++) { 164 int8_t *dst = &B_packed[(j16 * n4 + k4) * 64]; 165 for (size_t dj = 0; dj < 16; dj++) { 166 size_t j = j16 * 16 + dj; 167 for (size_t dk = 0; dk < 4; dk++) { 168 dst[dj * 4 + dk] = B[(k4 * 4 + dk) * p + j]; 169 } 170 } 171 } 172 } 173 } 174 175 int matmul_avx512vnni_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, 176 double scale) { 177 (void)scale; 178 179 size_t n4 = n / 4; 180 size_t p16 = p / 16; 181 182 int8_t *B_packed; 183 if (posix_memalign((void **)&B_packed, 64, p16 * n4 * 64) != 0) return -1; 184 pack_b_i8(n, p, B, B_packed); 185 186 const uint32_t *A32 = (const uint32_t *)A; 187 188 #pragma omp parallel for 
schedule(static) 189 for (size_t i = 0; i < m; i++) { 190 for (size_t j16 = 0; j16 < p16; j16++) { 191 __m512i result = _mm512_setzero_si512(); 192 for (size_t k4 = 0; k4 < n4; k4++) { 193 __m512i a_val = _mm512_set1_epi32(A32[i * n4 + k4]); 194 __m512i b_val = _mm512_load_si512((__m512i const *)&B_packed[(j16 * n4 + k4) * 64]); 195 result = _mm512_dpbusd_epi32(result, a_val, b_val); 196 } 197 int32_t tmp[16] __attribute__((aligned(64))); 198 _mm512_store_si512(tmp, result); 199 for (size_t dj = 0; dj < 16; dj++) { 200 int32_t v = tmp[dj]; 201 if (scale > 1.0) v = (int32_t)(v / scale); 202 if (v > 255) 203 v = 255; 204 else if (v < 0) 205 v = 0; 206 C[i * p + j16 * 16 + dj] = (uint8_t)v; 207 } 208 } 209 for (size_t j = p16 * 16; j < p; j++) { 210 int32_t sum = 0; 211 for (size_t k = 0; k < n; k++) { 212 sum += (int)A[i * n + k] * (int)B[k * p + j]; 213 } 214 if (scale > 1.0) sum = (int32_t)(sum / scale); 215 if (sum > 255) 216 sum = 255; 217 else if (sum < 0) 218 sum = 0; 219 C[i * p + j] = (uint8_t)sum; 220 } 221 } 222 223 free(B_packed); 224 return 0; 225 } 226 #endif 227 228 static int _matmul_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { 229 static int initialized = 0; 230 if (!initialized) { 231 matmul_feature_t feat = matmul_get_feature(); 232 #ifdef __AVX512VNNI__ 233 if (feat & MATMUL_FLAG_AVX512VNNI) 234 matmul_u8_i8_u8 = matmul_avx512vnni_u8_i8_u8; 235 else 236 #endif 237 matmul_u8_i8_u8 = matmul_scalar_u8_i8_u8; 238 initialized = 1; 239 } 240 return matmul_u8_i8_u8(m, n, p, A, B, C, scale); 241 } 242 243 int (*matmul_u8_i8_u8)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, double) = _matmul_u8_i8_u8; 244 245 /* ========================================================================== */ 246 /* f32_f32_f32 implementations */ 247 /* ========================================================================== */ 248 249 int matmul_scalar_f32_f32_f32(size_t m, size_t n, size_t p, 
const float *A, const float *B, float *C, double scale) { 250 const size_t ib = 64; 251 const size_t jb = 64; 252 const size_t kb = 16; 253 254 #pragma omp parallel for schedule(static) 255 for (size_t ii = 0; ii < m; ii += ib) { 256 size_t i_end = (ii + ib < m) ? ii + ib : m; 257 for (size_t jj = 0; jj < p; jj += jb) { 258 size_t j_end = (jj + jb < p) ? jj + jb : p; 259 size_t ti = i_end - ii; 260 size_t tj = j_end - jj; 261 double acc[64 * 64]; 262 memset(acc, 0, ti * tj * sizeof(double)); 263 264 for (size_t kk = 0; kk < n; kk += kb) { 265 size_t k_end = (kk + kb < n) ? kk + kb : n; 266 for (size_t i = ii; i < i_end; i++) { 267 size_t li = i - ii; 268 for (size_t j = jj; j < j_end; j++) { 269 size_t lj = j - jj; 270 double sum = 0.0; 271 for (size_t k = kk; k < k_end; k++) { 272 sum += (double)A[i * n + k] * (double)B[k * p + j]; 273 } 274 acc[li * tj + lj] += sum; 275 } 276 } 277 } 278 279 for (size_t i = ii; i < i_end; i++) { 280 size_t li = i - ii; 281 for (size_t j = jj; j < j_end; j++) { 282 size_t lj = j - jj; 283 double v = acc[li * tj + lj]; 284 if (scale > 1.0) v /= scale; 285 C[i * p + j] = (float)v; 286 } 287 } 288 } 289 } 290 return 0; 291 } 292 293 #ifdef __AVX2__ 294 static void pack_b_f32(size_t n, size_t p, const float *B, float *B_packed) { 295 size_t n8 = n / 8; 296 size_t p8 = p / 8; 297 for (size_t j8 = 0; j8 < p8; j8++) { 298 for (size_t k8 = 0; k8 < n8; k8++) { 299 float *dst = &B_packed[(j8 * n8 + k8) * 64]; 300 for (size_t dk = 0; dk < 8; dk++) { 301 size_t k = k8 * 8 + dk; 302 for (size_t dj = 0; dj < 8; dj++) { 303 size_t j = j8 * 8 + dj; 304 dst[dk * 8 + dj] = B[k * p + j]; 305 } 306 } 307 } 308 } 309 } 310 311 int matmul_avx2_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { 312 const size_t ib = 64; 313 const size_t jb = 64; 314 const size_t kb = 16; 315 316 size_t n8 = n / 8; 317 size_t p8 = p / 8; 318 319 float *B_packed; 320 if (posix_memalign((void **)&B_packed, 64, p8 * n8 * 64 * 
sizeof(float)) != 0) return -1; 321 pack_b_f32(n, p, B, B_packed); 322 323 float inv_scale = (scale > 1.0) ? 1.0f / (float)scale : 1.0f; 324 325 #pragma omp parallel for schedule(static) 326 for (size_t ii = 0; ii < m; ii += ib) { 327 size_t i_end = (ii + ib < m) ? ii + ib : m; 328 for (size_t jj = 0; jj < p8 * 8; jj += jb) { 329 size_t j_end = (jj + jb < p8 * 8) ? jj + jb : p8 * 8; 330 size_t ti = i_end - ii; 331 size_t tj = j_end - jj; 332 float acc[64 * 64]; 333 memset(acc, 0, ti * tj * sizeof(float)); 334 335 for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) { 336 size_t li0 = i4 - ii; 337 size_t li1 = li0 + 1; 338 size_t li2 = li0 + 2; 339 size_t li3 = li0 + 3; 340 size_t rn0 = i4 * n; 341 size_t rn1 = (i4 + 1) * n; 342 size_t rn2 = (i4 + 2) * n; 343 size_t rn3 = (i4 + 3) * n; 344 345 for (size_t j = jj; j + 8 <= j_end; j += 8) { 346 size_t lj = j - jj; 347 size_t j8idx = j / 8; 348 __m256 acc00 = _mm256_setzero_ps(); 349 __m256 acc01 = _mm256_setzero_ps(); 350 __m256 acc02 = _mm256_setzero_ps(); 351 __m256 acc03 = _mm256_setzero_ps(); 352 __m256 acc10 = _mm256_setzero_ps(); 353 __m256 acc11 = _mm256_setzero_ps(); 354 __m256 acc12 = _mm256_setzero_ps(); 355 __m256 acc13 = _mm256_setzero_ps(); 356 __m256 acc20 = _mm256_setzero_ps(); 357 __m256 acc21 = _mm256_setzero_ps(); 358 __m256 acc22 = _mm256_setzero_ps(); 359 __m256 acc23 = _mm256_setzero_ps(); 360 __m256 acc30 = _mm256_setzero_ps(); 361 __m256 acc31 = _mm256_setzero_ps(); 362 __m256 acc32 = _mm256_setzero_ps(); 363 __m256 acc33 = _mm256_setzero_ps(); 364 365 for (size_t kk = 0; kk < n; kk += kb) { 366 size_t k_end = (kk + kb < n) ? 
kk + kb : n; 367 size_t k_end8 = kk + (k_end - kk) / 8 * 8; 368 size_t k = kk; 369 for (; k + 4 <= k_end8; k += 4) { 370 size_t k8_0 = k / 8, dk_0 = k % 8; 371 size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8; 372 size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8; 373 size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8; 374 __m256 a0 = _mm256_set1_ps(A[rn0 + k]); 375 __m256 a1 = _mm256_set1_ps(A[rn0 + k + 1]); 376 __m256 a2 = _mm256_set1_ps(A[rn0 + k + 2]); 377 __m256 a3 = _mm256_set1_ps(A[rn0 + k + 3]); 378 __m256 b0 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]); 379 __m256 b1 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]); 380 __m256 b2 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]); 381 __m256 b3 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]); 382 acc00 = _mm256_fmadd_ps(a0, b0, acc00); 383 acc01 = _mm256_fmadd_ps(a1, b1, acc01); 384 acc02 = _mm256_fmadd_ps(a2, b2, acc02); 385 acc03 = _mm256_fmadd_ps(a3, b3, acc03); 386 a0 = _mm256_set1_ps(A[rn1 + k]); 387 a1 = _mm256_set1_ps(A[rn1 + k + 1]); 388 a2 = _mm256_set1_ps(A[rn1 + k + 2]); 389 a3 = _mm256_set1_ps(A[rn1 + k + 3]); 390 acc10 = _mm256_fmadd_ps(a0, b0, acc10); 391 acc11 = _mm256_fmadd_ps(a1, b1, acc11); 392 acc12 = _mm256_fmadd_ps(a2, b2, acc12); 393 acc13 = _mm256_fmadd_ps(a3, b3, acc13); 394 a0 = _mm256_set1_ps(A[rn2 + k]); 395 a1 = _mm256_set1_ps(A[rn2 + k + 1]); 396 a2 = _mm256_set1_ps(A[rn2 + k + 2]); 397 a3 = _mm256_set1_ps(A[rn2 + k + 3]); 398 acc20 = _mm256_fmadd_ps(a0, b0, acc20); 399 acc21 = _mm256_fmadd_ps(a1, b1, acc21); 400 acc22 = _mm256_fmadd_ps(a2, b2, acc22); 401 acc23 = _mm256_fmadd_ps(a3, b3, acc23); 402 a0 = _mm256_set1_ps(A[rn3 + k]); 403 a1 = _mm256_set1_ps(A[rn3 + k + 1]); 404 a2 = _mm256_set1_ps(A[rn3 + k + 2]); 405 a3 = _mm256_set1_ps(A[rn3 + k + 3]); 406 acc30 = _mm256_fmadd_ps(a0, b0, acc30); 407 acc31 = _mm256_fmadd_ps(a1, b1, acc31); 408 acc32 = _mm256_fmadd_ps(a2, b2, acc32); 409 acc33 = _mm256_fmadd_ps(a3, b3, 
acc33); 410 } 411 for (; k < k_end; k++) { 412 size_t k8 = k / 8, dk = k % 8; 413 __m256 a_bcast = _mm256_set1_ps(A[rn0 + k]); 414 __m256 b_val = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]); 415 acc00 = _mm256_fmadd_ps(a_bcast, b_val, acc00); 416 a_bcast = _mm256_set1_ps(A[rn1 + k]); 417 acc10 = _mm256_fmadd_ps(a_bcast, b_val, acc10); 418 a_bcast = _mm256_set1_ps(A[rn2 + k]); 419 acc20 = _mm256_fmadd_ps(a_bcast, b_val, acc20); 420 a_bcast = _mm256_set1_ps(A[rn3 + k]); 421 acc30 = _mm256_fmadd_ps(a_bcast, b_val, acc30); 422 } 423 } 424 425 acc00 = _mm256_add_ps(acc00, acc01); 426 acc02 = _mm256_add_ps(acc02, acc03); 427 acc00 = _mm256_add_ps(acc00, acc02); 428 acc10 = _mm256_add_ps(acc10, acc11); 429 acc12 = _mm256_add_ps(acc12, acc13); 430 acc10 = _mm256_add_ps(acc10, acc12); 431 acc20 = _mm256_add_ps(acc20, acc21); 432 acc22 = _mm256_add_ps(acc22, acc23); 433 acc20 = _mm256_add_ps(acc20, acc22); 434 acc30 = _mm256_add_ps(acc30, acc31); 435 acc32 = _mm256_add_ps(acc32, acc33); 436 acc30 = _mm256_add_ps(acc30, acc32); 437 438 float tmp[8] __attribute__((aligned(32))); 439 _mm256_store_ps(tmp, acc00); 440 for (size_t dj = 0; dj < 8; dj++) acc[li0 * tj + lj + dj] += tmp[dj]; 441 _mm256_store_ps(tmp, acc10); 442 for (size_t dj = 0; dj < 8; dj++) acc[li1 * tj + lj + dj] += tmp[dj]; 443 _mm256_store_ps(tmp, acc20); 444 for (size_t dj = 0; dj < 8; dj++) acc[li2 * tj + lj + dj] += tmp[dj]; 445 _mm256_store_ps(tmp, acc30); 446 for (size_t dj = 0; dj < 8; dj++) acc[li3 * tj + lj + dj] += tmp[dj]; 447 } 448 449 for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { 450 size_t lj = j - jj; 451 double sum0 = 0.0; 452 double sum1 = 0.0; 453 double sum2 = 0.0; 454 double sum3 = 0.0; 455 for (size_t k = 0; k < n; k++) { 456 sum0 += (double)A[rn0 + k] * (double)B[k * p + j]; 457 sum1 += (double)A[rn1 + k] * (double)B[k * p + j]; 458 sum2 += (double)A[rn2 + k] * (double)B[k * p + j]; 459 sum3 += (double)A[rn3 + k] * (double)B[k * p + j]; 460 } 461 acc[li0 * tj + lj] += 
(float)sum0; 462 acc[li1 * tj + lj] += (float)sum1; 463 acc[li2 * tj + lj] += (float)sum2; 464 acc[li3 * tj + lj] += (float)sum3; 465 } 466 } 467 468 for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) { 469 size_t li = i - ii; 470 size_t rn = i * n; 471 for (size_t j = jj; j + 8 <= j_end; j += 8) { 472 size_t lj = j - jj; 473 size_t j8idx = j / 8; 474 __m256 acc0 = _mm256_setzero_ps(); 475 __m256 acc1 = _mm256_setzero_ps(); 476 __m256 acc2 = _mm256_setzero_ps(); 477 __m256 acc3 = _mm256_setzero_ps(); 478 479 for (size_t kk = 0; kk < n; kk += kb) { 480 size_t k_end = (kk + kb < n) ? kk + kb : n; 481 size_t k_end8 = kk + (k_end - kk) / 8 * 8; 482 size_t k = kk; 483 for (; k + 4 <= k_end8; k += 4) { 484 size_t k8_0 = k / 8, dk_0 = k % 8; 485 size_t k8_1 = (k + 1) / 8, dk_1 = (k + 1) % 8; 486 size_t k8_2 = (k + 2) / 8, dk_2 = (k + 2) % 8; 487 size_t k8_3 = (k + 3) / 8, dk_3 = (k + 3) % 8; 488 __m256 a0 = _mm256_set1_ps(A[rn + k]); 489 __m256 a1 = _mm256_set1_ps(A[rn + k + 1]); 490 __m256 a2 = _mm256_set1_ps(A[rn + k + 2]); 491 __m256 a3 = _mm256_set1_ps(A[rn + k + 3]); 492 __m256 b0 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_0) * 64 + dk_0 * 8]); 493 __m256 b1 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_1) * 64 + dk_1 * 8]); 494 __m256 b2 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_2) * 64 + dk_2 * 8]); 495 __m256 b3 = _mm256_load_ps(&B_packed[(j8idx * n8 + k8_3) * 64 + dk_3 * 8]); 496 acc0 = _mm256_fmadd_ps(a0, b0, acc0); 497 acc1 = _mm256_fmadd_ps(a1, b1, acc1); 498 acc2 = _mm256_fmadd_ps(a2, b2, acc2); 499 acc3 = _mm256_fmadd_ps(a3, b3, acc3); 500 } 501 for (; k < k_end; k++) { 502 size_t k8 = k / 8, dk = k % 8; 503 __m256 a_bcast = _mm256_set1_ps(A[rn + k]); 504 __m256 b_val = _mm256_load_ps(&B_packed[(j8idx * n8 + k8) * 64 + dk * 8]); 505 acc0 = _mm256_fmadd_ps(a_bcast, b_val, acc0); 506 } 507 } 508 509 acc0 = _mm256_add_ps(acc0, acc1); 510 acc2 = _mm256_add_ps(acc2, acc3); 511 acc0 = _mm256_add_ps(acc0, acc2); 512 513 float tmp[8] 
__attribute__((aligned(32))); 514 _mm256_store_ps(tmp, acc0); 515 for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj]; 516 } 517 for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { 518 size_t lj = j - jj; 519 for (size_t k = 0; k < n; k++) { 520 acc[li * tj + lj] += (double)A[rn + k] * (double)B[k * p + j]; 521 } 522 } 523 } 524 525 for (size_t i = ii; i < i_end; i++) { 526 size_t li = i - ii; 527 for (size_t j = jj; j < j_end; j++) { 528 size_t lj = j - jj; 529 C[i * p + j] = acc[li * tj + lj] * inv_scale; 530 } 531 } 532 } 533 } 534 535 for (size_t i = 0; i < m; i++) { 536 for (size_t j = p8 * 8; j < p; j++) { 537 double sum = 0.0; 538 for (size_t k = 0; k < n; k++) { 539 sum += (double)A[i * n + k] * (double)B[k * p + j]; 540 } 541 C[i * p + j] = (float)(sum * inv_scale); 542 } 543 } 544 545 free(B_packed); 546 return 0; 547 } 548 #endif 549 550 #ifdef __AVX512F__ 551 static void pack_b_f32_512(size_t n, size_t p, const float *B, float *B_packed) { 552 size_t n16 = n / 16; 553 size_t p16 = p / 16; 554 for (size_t j16 = 0; j16 < p16; j16++) { 555 for (size_t k16 = 0; k16 < n16; k16++) { 556 float *dst = &B_packed[(j16 * n16 + k16) * 256]; 557 for (size_t dk = 0; dk < 16; dk++) { 558 size_t k = k16 * 16 + dk; 559 for (size_t dj = 0; dj < 16; dj++) { 560 size_t j = j16 * 16 + dj; 561 dst[dk * 16 + dj] = B[k * p + j]; 562 } 563 } 564 } 565 } 566 } 567 568 int matmul_avx512_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { 569 const size_t ib = 64; 570 const size_t jb = 64; 571 const size_t kb = 16; 572 573 size_t n16 = n / 16; 574 size_t p16 = p / 16; 575 576 float *B_packed; 577 if (posix_memalign((void **)&B_packed, 64, p16 * n16 * 256 * sizeof(float)) != 0) return -1; 578 pack_b_f32_512(n, p, B, B_packed); 579 580 float inv_scale = (scale > 1.0) ? 
1.0f / (float)scale : 1.0f; 581 582 #pragma omp parallel for schedule(static) 583 for (size_t ii = 0; ii < m; ii += ib) { 584 size_t i_end = (ii + ib < m) ? ii + ib : m; 585 for (size_t jj = 0; jj < p16 * 16; jj += jb) { 586 size_t j_end = (jj + jb < p16 * 16) ? jj + jb : p16 * 16; 587 size_t ti = i_end - ii; 588 size_t tj = j_end - jj; 589 float acc[64 * 64]; 590 memset(acc, 0, ti * tj * sizeof(float)); 591 592 for (size_t i4 = ii; i4 + 4 <= i_end; i4 += 4) { 593 size_t li0 = i4 - ii; 594 size_t li1 = li0 + 1; 595 size_t li2 = li0 + 2; 596 size_t li3 = li0 + 3; 597 size_t rn0 = i4 * n; 598 size_t rn1 = (i4 + 1) * n; 599 size_t rn2 = (i4 + 2) * n; 600 size_t rn3 = (i4 + 3) * n; 601 602 for (size_t j = jj; j + 16 <= j_end; j += 16) { 603 size_t lj = j - jj; 604 size_t j16idx = j / 16; 605 __m512 acc00 = _mm512_setzero_ps(); 606 __m512 acc01 = _mm512_setzero_ps(); 607 __m512 acc02 = _mm512_setzero_ps(); 608 __m512 acc03 = _mm512_setzero_ps(); 609 __m512 acc10 = _mm512_setzero_ps(); 610 __m512 acc11 = _mm512_setzero_ps(); 611 __m512 acc12 = _mm512_setzero_ps(); 612 __m512 acc13 = _mm512_setzero_ps(); 613 __m512 acc20 = _mm512_setzero_ps(); 614 __m512 acc21 = _mm512_setzero_ps(); 615 __m512 acc22 = _mm512_setzero_ps(); 616 __m512 acc23 = _mm512_setzero_ps(); 617 __m512 acc30 = _mm512_setzero_ps(); 618 __m512 acc31 = _mm512_setzero_ps(); 619 __m512 acc32 = _mm512_setzero_ps(); 620 __m512 acc33 = _mm512_setzero_ps(); 621 622 for (size_t kk = 0; kk < n; kk += kb) { 623 size_t k_end = (kk + kb < n) ? 
kk + kb : n; 624 size_t k_end16 = kk + (k_end - kk) / 16 * 16; 625 size_t k = kk; 626 for (; k + 4 <= k_end16; k += 4) { 627 size_t k16_0 = k / 16, dk_0 = k % 16; 628 size_t k16_1 = (k + 1) / 16, dk_1 = (k + 1) % 16; 629 size_t k16_2 = (k + 2) / 16, dk_2 = (k + 2) % 16; 630 size_t k16_3 = (k + 3) / 16, dk_3 = (k + 3) % 16; 631 __m512 a0 = _mm512_set1_ps(A[rn0 + k]); 632 __m512 a1 = _mm512_set1_ps(A[rn0 + k + 1]); 633 __m512 a2 = _mm512_set1_ps(A[rn0 + k + 2]); 634 __m512 a3 = _mm512_set1_ps(A[rn0 + k + 3]); 635 __m512 b0 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_0) * 256 + dk_0 * 16]); 636 __m512 b1 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_1) * 256 + dk_1 * 16]); 637 __m512 b2 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_2) * 256 + dk_2 * 16]); 638 __m512 b3 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_3) * 256 + dk_3 * 16]); 639 acc00 = _mm512_fmadd_ps(a0, b0, acc00); 640 acc01 = _mm512_fmadd_ps(a1, b1, acc01); 641 acc02 = _mm512_fmadd_ps(a2, b2, acc02); 642 acc03 = _mm512_fmadd_ps(a3, b3, acc03); 643 a0 = _mm512_set1_ps(A[rn1 + k]); 644 a1 = _mm512_set1_ps(A[rn1 + k + 1]); 645 a2 = _mm512_set1_ps(A[rn1 + k + 2]); 646 a3 = _mm512_set1_ps(A[rn1 + k + 3]); 647 acc10 = _mm512_fmadd_ps(a0, b0, acc10); 648 acc11 = _mm512_fmadd_ps(a1, b1, acc11); 649 acc12 = _mm512_fmadd_ps(a2, b2, acc12); 650 acc13 = _mm512_fmadd_ps(a3, b3, acc13); 651 a0 = _mm512_set1_ps(A[rn2 + k]); 652 a1 = _mm512_set1_ps(A[rn2 + k + 1]); 653 a2 = _mm512_set1_ps(A[rn2 + k + 2]); 654 a3 = _mm512_set1_ps(A[rn2 + k + 3]); 655 acc20 = _mm512_fmadd_ps(a0, b0, acc20); 656 acc21 = _mm512_fmadd_ps(a1, b1, acc21); 657 acc22 = _mm512_fmadd_ps(a2, b2, acc22); 658 acc23 = _mm512_fmadd_ps(a3, b3, acc23); 659 a0 = _mm512_set1_ps(A[rn3 + k]); 660 a1 = _mm512_set1_ps(A[rn3 + k + 1]); 661 a2 = _mm512_set1_ps(A[rn3 + k + 2]); 662 a3 = _mm512_set1_ps(A[rn3 + k + 3]); 663 acc30 = _mm512_fmadd_ps(a0, b0, acc30); 664 acc31 = _mm512_fmadd_ps(a1, b1, acc31); 665 acc32 = _mm512_fmadd_ps(a2, b2, acc32); 666 
acc33 = _mm512_fmadd_ps(a3, b3, acc33); 667 } 668 for (; k < k_end; k++) { 669 size_t k16 = k / 16, dk = k % 16; 670 __m512 a_bcast = _mm512_set1_ps(A[rn0 + k]); 671 __m512 b_val = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]); 672 acc00 = _mm512_fmadd_ps(a_bcast, b_val, acc00); 673 a_bcast = _mm512_set1_ps(A[rn1 + k]); 674 acc10 = _mm512_fmadd_ps(a_bcast, b_val, acc10); 675 a_bcast = _mm512_set1_ps(A[rn2 + k]); 676 acc20 = _mm512_fmadd_ps(a_bcast, b_val, acc20); 677 a_bcast = _mm512_set1_ps(A[rn3 + k]); 678 acc30 = _mm512_fmadd_ps(a_bcast, b_val, acc30); 679 } 680 } 681 682 acc00 = _mm512_add_ps(acc00, acc01); 683 acc02 = _mm512_add_ps(acc02, acc03); 684 acc00 = _mm512_add_ps(acc00, acc02); 685 acc10 = _mm512_add_ps(acc10, acc11); 686 acc12 = _mm512_add_ps(acc12, acc13); 687 acc10 = _mm512_add_ps(acc10, acc12); 688 acc20 = _mm512_add_ps(acc20, acc21); 689 acc22 = _mm512_add_ps(acc22, acc23); 690 acc20 = _mm512_add_ps(acc20, acc22); 691 acc30 = _mm512_add_ps(acc30, acc31); 692 acc32 = _mm512_add_ps(acc32, acc33); 693 acc30 = _mm512_add_ps(acc30, acc32); 694 695 float tmp[16] __attribute__((aligned(64))); 696 _mm512_store_ps(tmp, acc00); 697 for (size_t dj = 0; dj < 16; dj++) acc[li0 * tj + lj + dj] += tmp[dj]; 698 _mm512_store_ps(tmp, acc10); 699 for (size_t dj = 0; dj < 16; dj++) acc[li1 * tj + lj + dj] += tmp[dj]; 700 _mm512_store_ps(tmp, acc20); 701 for (size_t dj = 0; dj < 16; dj++) acc[li2 * tj + lj + dj] += tmp[dj]; 702 _mm512_store_ps(tmp, acc30); 703 for (size_t dj = 0; dj < 16; dj++) acc[li3 * tj + lj + dj] += tmp[dj]; 704 } 705 706 for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) { 707 size_t lj = j - jj; 708 double sum0 = 0.0; 709 double sum1 = 0.0; 710 double sum2 = 0.0; 711 double sum3 = 0.0; 712 for (size_t k = 0; k < n; k++) { 713 sum0 += (double)A[rn0 + k] * (double)B[k * p + j]; 714 sum1 += (double)A[rn1 + k] * (double)B[k * p + j]; 715 sum2 += (double)A[rn2 + k] * (double)B[k * p + j]; 716 sum3 += (double)A[rn3 + k] * 
(double)B[k * p + j]; 717 } 718 acc[li0 * tj + lj] += (float)sum0; 719 acc[li1 * tj + lj] += (float)sum1; 720 acc[li2 * tj + lj] += (float)sum2; 721 acc[li3 * tj + lj] += (float)sum3; 722 } 723 } 724 725 for (size_t i = ii + (i_end - ii) / 4 * 4; i < i_end; i++) { 726 size_t li = i - ii; 727 size_t rn = i * n; 728 for (size_t j = jj; j + 16 <= j_end; j += 16) { 729 size_t lj = j - jj; 730 size_t j16idx = j / 16; 731 __m512 acc0 = _mm512_setzero_ps(); 732 __m512 acc1 = _mm512_setzero_ps(); 733 __m512 acc2 = _mm512_setzero_ps(); 734 __m512 acc3 = _mm512_setzero_ps(); 735 736 for (size_t kk = 0; kk < n; kk += kb) { 737 size_t k_end = (kk + kb < n) ? kk + kb : n; 738 size_t k_end16 = kk + (k_end - kk) / 16 * 16; 739 size_t k = kk; 740 for (; k + 4 <= k_end16; k += 4) { 741 size_t k16_0 = k / 16, dk_0 = k % 16; 742 size_t k16_1 = (k + 1) / 16, dk_1 = (k + 1) % 16; 743 size_t k16_2 = (k + 2) / 16, dk_2 = (k + 2) % 16; 744 size_t k16_3 = (k + 3) / 16, dk_3 = (k + 3) % 16; 745 __m512 a0 = _mm512_set1_ps(A[rn + k]); 746 __m512 a1 = _mm512_set1_ps(A[rn + k + 1]); 747 __m512 a2 = _mm512_set1_ps(A[rn + k + 2]); 748 __m512 a3 = _mm512_set1_ps(A[rn + k + 3]); 749 __m512 b0 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_0) * 256 + dk_0 * 16]); 750 __m512 b1 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_1) * 256 + dk_1 * 16]); 751 __m512 b2 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_2) * 256 + dk_2 * 16]); 752 __m512 b3 = _mm512_load_ps(&B_packed[(j16idx * n16 + k16_3) * 256 + dk_3 * 16]); 753 acc0 = _mm512_fmadd_ps(a0, b0, acc0); 754 acc1 = _mm512_fmadd_ps(a1, b1, acc1); 755 acc2 = _mm512_fmadd_ps(a2, b2, acc2); 756 acc3 = _mm512_fmadd_ps(a3, b3, acc3); 757 } 758 for (; k < k_end; k++) { 759 size_t k16 = k / 16, dk = k % 16; 760 __m512 a_bcast = _mm512_set1_ps(A[rn + k]); 761 __m512 b_val = _mm512_load_ps(&B_packed[(j16idx * n16 + k16) * 256 + dk * 16]); 762 acc0 = _mm512_fmadd_ps(a_bcast, b_val, acc0); 763 } 764 } 765 766 acc0 = _mm512_add_ps(acc0, acc1); 767 acc2 = 
_mm512_add_ps(acc2, acc3); 768 acc0 = _mm512_add_ps(acc0, acc2); 769 770 float tmp[16] __attribute__((aligned(64))); 771 _mm512_store_ps(tmp, acc0); 772 for (size_t dj = 0; dj < 16; dj++) acc[li * tj + lj + dj] += tmp[dj]; 773 } 774 for (size_t j = jj + (tj / 16) * 16; j < j_end; j++) { 775 size_t lj = j - jj; 776 for (size_t k = 0; k < n; k++) { 777 acc[li * tj + lj] += (double)A[rn + k] * (double)B[k * p + j]; 778 } 779 } 780 } 781 782 for (size_t i = ii; i < i_end; i++) { 783 size_t li = i - ii; 784 for (size_t j = jj; j < j_end; j++) { 785 size_t lj = j - jj; 786 C[i * p + j] = acc[li * tj + lj] * inv_scale; 787 } 788 } 789 } 790 } 791 792 for (size_t i = 0; i < m; i++) { 793 for (size_t j = p16 * 16; j < p; j++) { 794 double sum = 0.0; 795 for (size_t k = 0; k < n; k++) { 796 sum += (double)A[i * n + k] * (double)B[k * p + j]; 797 } 798 C[i * p + j] = (float)(sum * inv_scale); 799 } 800 } 801 802 free(B_packed); 803 return 0; 804 } 805 #endif 806 807 static int _matmul_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { 808 static int initialized = 0; 809 if (!initialized) { 810 matmul_feature_t feat = matmul_get_feature(); 811 #ifdef __AVX512F__ 812 if (feat & MATMUL_FLAG_AVX512) 813 matmul_f32_f32_f32 = matmul_avx512_f32_f32_f32; 814 else 815 #endif 816 #ifdef __AVX2__ 817 if (feat & MATMUL_FLAG_AVX2) 818 matmul_f32_f32_f32 = matmul_avx2_f32_f32_f32; 819 else 820 #endif 821 matmul_f32_f32_f32 = matmul_scalar_f32_f32_f32; 822 initialized = 1; 823 } 824 return matmul_f32_f32_f32(m, n, p, A, B, C, scale); 825 } 826 827 int (*matmul_f32_f32_f32)(size_t, size_t, size_t, const float *, const float *, float *, double) = _matmul_f32_f32_f32; 828 829 /* ========================================================================== */ 830 /* f64_f64_f64 implementations */ 831 /* ========================================================================== */ 832 833 int matmul_scalar_f64_f64_f64(size_t m, size_t n, size_t p, 
/*
 * Portable double-precision matrix product.
 *
 * Computes C = A * B where A is read as A[i * n + k] (m rows, n columns),
 * B as B[k * p + j] (n rows, p columns) and C is written as C[i * p + j].
 * Every element of the product is divided by `scale` when scale > 1.0;
 * any other scale value leaves the product unscaled.
 *
 * The output is processed in tiles of at most 64 x 64 elements. Each tile is
 * accumulated in a local buffer, eight k values at a time, and then written
 * back to C. Always returns 0.
 */
int matmul_scalar_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
  const size_t tile_rows = 64;
  const size_t tile_cols = 64;
  const size_t k_step = 8;

#pragma omp parallel for schedule(static)
  for (size_t row0 = 0; row0 < m; row0 += tile_rows) {
    const size_t row1 = (row0 + tile_rows < m) ? row0 + tile_rows : m;
    const size_t rows = row1 - row0;

    for (size_t col0 = 0; col0 < p; col0 += tile_cols) {
      const size_t col1 = (col0 + tile_cols < p) ? col0 + tile_cols : p;
      const size_t cols = col1 - col0;

      /* Tile accumulator, laid out with a stride of `cols`. */
      double tile[64 * 64];
      memset(tile, 0, rows * cols * sizeof(double));

      for (size_t k0 = 0; k0 < n; k0 += k_step) {
        const size_t k1 = (k0 + k_step < n) ? k0 + k_step : n;

        for (size_t r = 0; r < rows; r++) {
          const double *a_row = &A[(row0 + r) * n];
          double *t_row = &tile[r * cols];

          for (size_t c = 0; c < cols; c++) {
            /* Partial dot product over this k slice, added to the tile afterwards. */
            double partial = 0.0;
            for (size_t k = k0; k < k1; k++) {
              partial += a_row[k] * B[k * p + col0 + c];
            }
            t_row[c] += partial;
          }
        }
      }

      /* Write the finished tile back, applying the optional scale. */
      for (size_t r = 0; r < rows; r++) {
        double *c_row = &C[(row0 + r) * p + col0];
        const double *t_row = &tile[r * cols];
        for (size_t c = 0; c < cols; c++) {
          c_row[c] = (scale > 1.0) ? t_row[c] / scale : t_row[c];
        }
      }
    }
  }

  return 0;
}
#ifdef __AVX2__
/*
 * Repack the vectorizable part of B (row-major, n x p) into 4x4 tiles.
 *
 * With n4 = n / 4, tile (j4, k4) starts at B_packed[(j4 * n4 + k4) * 16] and
 * stores B[(k4 * 4 + dk) * p + (j4 * 4 + dj)] at offset dk * 4 + dj, so the
 * four B values needed for one k form one contiguous 4-double vector.
 *
 * Only complete tiles are written: rows k >= (n / 4) * 4 and columns
 * j >= (p / 4) * 4 are NOT present in B_packed and must be read from B.
 */
static void pack_b_f64(size_t n, size_t p, const double *B, double *B_packed) {
  const size_t n4 = n / 4;
  const size_t p4 = p / 4;
  for (size_t j4 = 0; j4 < p4; j4++) {
    for (size_t k4 = 0; k4 < n4; k4++) {
      double *dst = &B_packed[(j4 * n4 + k4) * 16];
      for (size_t dk = 0; dk < 4; dk++) {
        const double *src = &B[(k4 * 4 + dk) * p + j4 * 4];
        for (size_t dj = 0; dj < 4; dj++) {
          dst[dk * 4 + dj] = src[dj];
        }
      }
    }
  }
}

/* Sum the four per-lane-of-k accumulators of one row, scale, and store 4 outputs. */
static inline void avx2_f64_reduce_store(double *dst, __m256d s0, __m256d s1, __m256d s2, __m256d s3,
                                         __m256d v_inv) {
  const __m256d sum = _mm256_add_pd(_mm256_add_pd(s0, s1), _mm256_add_pd(s2, s3));
  _mm256_storeu_pd(dst, _mm256_mul_pd(sum, v_inv));
}

/*
 * AVX2/FMA double-precision matrix product.
 *
 * Computes C = A * B with A read as A[i * n + k], B as B[k * p + j] and C
 * written as C[i * p + j]. When scale > 1.0 every result is multiplied by
 * 1.0 / scale; otherwise the product is stored unscaled.
 *
 * Returns 0 on success and -1 when the packed copy of B cannot be allocated.
 *
 * Columns j < (p / 4) * 4 are produced four at a time by the vector kernels;
 * the remaining columns are computed by a scalar loop at the end.
 *
 * Fix over the previous revision: B_packed only holds k < (n / 4) * 4, but the
 * k tail used to index it at k / 4 == n / 4, which read the next column
 * panel (or past the allocation for the last panel) and produced wrong
 * results whenever n was not a multiple of 4. The tail now loads from B.
 */
int matmul_avx2_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
  const size_t ib = 64; /* rows of C per parallel work item */
  const size_t jb = 64; /* columns of C per tile */

  const size_t n4 = n / 4;
  const size_t p4 = p / 4;
  const size_t n_vec = n4 * 4; /* k values available in B_packed */
  const size_t p_vec = p4 * 4; /* columns handled by the vector kernels */

  double *B_packed = NULL;
  if (n4 != 0 && p4 != 0) {
    void *raw = NULL;
    if (posix_memalign(&raw, 64, p4 * n4 * 16 * sizeof(double)) != 0) return -1;
    B_packed = raw;
    pack_b_f64(n, p, B, B_packed);
  }

  const double inv_scale = (scale > 1.0) ? 1.0 / scale : 1.0;
  const __m256d v_inv = _mm256_set1_pd(inv_scale);

#pragma omp parallel for schedule(static)
  for (size_t ii = 0; ii < m; ii += ib) {
    const size_t i_end = (ii + ib < m) ? ii + ib : m;
    const size_t i_quad_end = ii + (i_end - ii) / 4 * 4;

    for (size_t jj = 0; jj < p_vec; jj += jb) {
      /* jj and p_vec are multiples of 4, so every tile is a whole number of vectors. */
      const size_t j_end = (jj + jb < p_vec) ? jj + jb : p_vec;

      /* 4 rows x 4 columns per step: each B vector is reused for four rows of A. */
      for (size_t i = ii; i < i_quad_end; i += 4) {
        const size_t r0 = i * n;
        const size_t r1 = r0 + n;
        const size_t r2 = r1 + n;
        const size_t r3 = r2 + n;

        for (size_t j = jj; j < j_end; j += 4) {
          const size_t panel = (j / 4) * n4 * 16;
          /* cRU: row R of the quad, accumulating the k values with k % 4 == U. */
          __m256d c00 = _mm256_setzero_pd(), c01 = c00, c02 = c00, c03 = c00;
          __m256d c10 = c00, c11 = c00, c12 = c00, c13 = c00;
          __m256d c20 = c00, c21 = c00, c22 = c00, c23 = c00;
          __m256d c30 = c00, c31 = c00, c32 = c00, c33 = c00;

          for (size_t k = 0; k < n_vec; k += 4) {
            /* Tile k / 4 of this panel; row dk of the tile is at offset dk * 4. */
            const double *tile = &B_packed[panel + k * 4];
            const __m256d b0 = _mm256_load_pd(tile);
            const __m256d b1 = _mm256_load_pd(tile + 4);
            const __m256d b2 = _mm256_load_pd(tile + 8);
            const __m256d b3 = _mm256_load_pd(tile + 12);

            c00 = _mm256_fmadd_pd(_mm256_set1_pd(A[r0 + k]), b0, c00);
            c01 = _mm256_fmadd_pd(_mm256_set1_pd(A[r0 + k + 1]), b1, c01);
            c02 = _mm256_fmadd_pd(_mm256_set1_pd(A[r0 + k + 2]), b2, c02);
            c03 = _mm256_fmadd_pd(_mm256_set1_pd(A[r0 + k + 3]), b3, c03);

            c10 = _mm256_fmadd_pd(_mm256_set1_pd(A[r1 + k]), b0, c10);
            c11 = _mm256_fmadd_pd(_mm256_set1_pd(A[r1 + k + 1]), b1, c11);
            c12 = _mm256_fmadd_pd(_mm256_set1_pd(A[r1 + k + 2]), b2, c12);
            c13 = _mm256_fmadd_pd(_mm256_set1_pd(A[r1 + k + 3]), b3, c13);

            c20 = _mm256_fmadd_pd(_mm256_set1_pd(A[r2 + k]), b0, c20);
            c21 = _mm256_fmadd_pd(_mm256_set1_pd(A[r2 + k + 1]), b1, c21);
            c22 = _mm256_fmadd_pd(_mm256_set1_pd(A[r2 + k + 2]), b2, c22);
            c23 = _mm256_fmadd_pd(_mm256_set1_pd(A[r2 + k + 3]), b3, c23);

            c30 = _mm256_fmadd_pd(_mm256_set1_pd(A[r3 + k]), b0, c30);
            c31 = _mm256_fmadd_pd(_mm256_set1_pd(A[r3 + k + 1]), b1, c31);
            c32 = _mm256_fmadd_pd(_mm256_set1_pd(A[r3 + k + 2]), b2, c32);
            c33 = _mm256_fmadd_pd(_mm256_set1_pd(A[r3 + k + 3]), b3, c33);
          }

          /* k tail: these rows of B were never packed, so read them from B itself. */
          for (size_t k = n_vec; k < n; k++) {
            const __m256d b = _mm256_loadu_pd(&B[k * p + j]);
            c00 = _mm256_fmadd_pd(_mm256_set1_pd(A[r0 + k]), b, c00);
            c10 = _mm256_fmadd_pd(_mm256_set1_pd(A[r1 + k]), b, c10);
            c20 = _mm256_fmadd_pd(_mm256_set1_pd(A[r2 + k]), b, c20);
            c30 = _mm256_fmadd_pd(_mm256_set1_pd(A[r3 + k]), b, c30);
          }

          avx2_f64_reduce_store(&C[i * p + j], c00, c01, c02, c03, v_inv);
          avx2_f64_reduce_store(&C[(i + 1) * p + j], c10, c11, c12, c13, v_inv);
          avx2_f64_reduce_store(&C[(i + 2) * p + j], c20, c21, c22, c23, v_inv);
          avx2_f64_reduce_store(&C[(i + 3) * p + j], c30, c31, c32, c33, v_inv);
        }
      }

      /* Up to three leftover rows of this row block: 1 row x 4 columns per step. */
      for (size_t i = i_quad_end; i < i_end; i++) {
        const size_t r = i * n;

        for (size_t j = jj; j < j_end; j += 4) {
          const size_t panel = (j / 4) * n4 * 16;
          __m256d c0 = _mm256_setzero_pd(), c1 = c0, c2 = c0, c3 = c0;

          for (size_t k = 0; k < n_vec; k += 4) {
            const double *tile = &B_packed[panel + k * 4];
            c0 = _mm256_fmadd_pd(_mm256_set1_pd(A[r + k]), _mm256_load_pd(tile), c0);
            c1 = _mm256_fmadd_pd(_mm256_set1_pd(A[r + k + 1]), _mm256_load_pd(tile + 4), c1);
            c2 = _mm256_fmadd_pd(_mm256_set1_pd(A[r + k + 2]), _mm256_load_pd(tile + 8), c2);
            c3 = _mm256_fmadd_pd(_mm256_set1_pd(A[r + k + 3]), _mm256_load_pd(tile + 12), c3);
          }
          for (size_t k = n_vec; k < n; k++) {
            c0 = _mm256_fmadd_pd(_mm256_set1_pd(A[r + k]), _mm256_loadu_pd(&B[k * p + j]), c0);
          }

          avx2_f64_reduce_store(&C[i * p + j], c0, c1, c2, c3, v_inv);
        }
      }
    }
  }

  /* Columns that do not fill a whole 4-wide vector. */
  for (size_t i = 0; i < m; i++) {
    for (size_t j = p_vec; j < p; j++) {
      double sum = 0.0;
      for (size_t k = 0; k < n; k++) {
        sum += A[i * n + k] * B[k * p + j];
      }
      C[i * p + j] = sum * inv_scale;
    }
  }

  free(B_packed);
  return 0;
}
#endif
#ifdef __AVX512F__
/*
 * Repack the vectorizable part of B (row-major, n x p) into 8x8 tiles.
 *
 * With n8 = n / 8, tile (j8, k8) starts at B_packed[(j8 * n8 + k8) * 64] and
 * stores B[(k8 * 8 + dk) * p + (j8 * 8 + dj)] at offset dk * 8 + dj, so the
 * eight B values needed for one k form one contiguous 8-double vector.
 *
 * Only complete tiles are written: rows k >= (n / 8) * 8 and columns
 * j >= (p / 8) * 8 are NOT present in B_packed and must be read from B.
 */
static void pack_b_f64_512(size_t n, size_t p, const double *B, double *B_packed) {
  const size_t n8 = n / 8;
  const size_t p8 = p / 8;
  for (size_t j8 = 0; j8 < p8; j8++) {
    for (size_t k8 = 0; k8 < n8; k8++) {
      double *dst = &B_packed[(j8 * n8 + k8) * 64];
      for (size_t dk = 0; dk < 8; dk++) {
        const double *src = &B[(k8 * 8 + dk) * p + j8 * 8];
        for (size_t dj = 0; dj < 8; dj++) {
          dst[dk * 8 + dj] = src[dj];
        }
      }
    }
  }
}

/* Sum the four per-lane-of-k accumulators of one row, scale, and store 8 outputs. */
static inline void avx512_f64_reduce_store(double *dst, __m512d s0, __m512d s1, __m512d s2, __m512d s3,
                                           __m512d v_inv) {
  const __m512d sum = _mm512_add_pd(_mm512_add_pd(s0, s1), _mm512_add_pd(s2, s3));
  _mm512_storeu_pd(dst, _mm512_mul_pd(sum, v_inv));
}

/*
 * AVX-512F double-precision matrix product.
 *
 * Computes C = A * B with A read as A[i * n + k], B as B[k * p + j] and C
 * written as C[i * p + j]. When scale > 1.0 every result is multiplied by
 * 1.0 / scale; otherwise the product is stored unscaled.
 *
 * Returns 0 on success and -1 when the packed copy of B cannot be allocated.
 *
 * Columns j < (p / 8) * 8 are produced eight at a time by the vector kernels;
 * the remaining columns are computed by a scalar loop at the end.
 *
 * Fix over the previous revision: B_packed only holds k < (n / 8) * 8, but the
 * k tail used to index it at k / 8 == n / 8, which read the next column
 * panel (or past the allocation for the last panel) and produced wrong
 * results whenever n was not a multiple of 8. The tail now loads from B.
 */
int matmul_avx512_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
  const size_t ib = 64; /* rows of C per parallel work item */
  const size_t jb = 64; /* columns of C per tile */

  const size_t n8 = n / 8;
  const size_t p8 = p / 8;
  const size_t n_vec = n8 * 8; /* k values available in B_packed */
  const size_t p_vec = p8 * 8; /* columns handled by the vector kernels */

  double *B_packed = NULL;
  if (n8 != 0 && p8 != 0) {
    void *raw = NULL;
    if (posix_memalign(&raw, 64, p8 * n8 * 64 * sizeof(double)) != 0) return -1;
    B_packed = raw;
    pack_b_f64_512(n, p, B, B_packed);
  }

  const double inv_scale = (scale > 1.0) ? 1.0 / scale : 1.0;
  const __m512d v_inv = _mm512_set1_pd(inv_scale);

#pragma omp parallel for schedule(static)
  for (size_t ii = 0; ii < m; ii += ib) {
    const size_t i_end = (ii + ib < m) ? ii + ib : m;
    const size_t i_quad_end = ii + (i_end - ii) / 4 * 4;

    for (size_t jj = 0; jj < p_vec; jj += jb) {
      /* jj and p_vec are multiples of 8, so every tile is a whole number of vectors. */
      const size_t j_end = (jj + jb < p_vec) ? jj + jb : p_vec;

      /* 4 rows x 8 columns per step: each B vector is reused for four rows of A. */
      for (size_t i = ii; i < i_quad_end; i += 4) {
        const size_t r0 = i * n;
        const size_t r1 = r0 + n;
        const size_t r2 = r1 + n;
        const size_t r3 = r2 + n;

        for (size_t j = jj; j < j_end; j += 8) {
          const size_t panel = (j / 8) * n8 * 64;
          /* cRU: row R of the quad, accumulating the k values with k % 4 == U. */
          __m512d c00 = _mm512_setzero_pd(), c01 = c00, c02 = c00, c03 = c00;
          __m512d c10 = c00, c11 = c00, c12 = c00, c13 = c00;
          __m512d c20 = c00, c21 = c00, c22 = c00, c23 = c00;
          __m512d c30 = c00, c31 = c00, c32 = c00, c33 = c00;

          for (size_t k = 0; k < n_vec; k += 4) {
            /* Row (k % 8) of tile (k / 8) lies at offset k * 8 within the panel. */
            const double *rows = &B_packed[panel + k * 8];
            const __m512d b0 = _mm512_load_pd(rows);
            const __m512d b1 = _mm512_load_pd(rows + 8);
            const __m512d b2 = _mm512_load_pd(rows + 16);
            const __m512d b3 = _mm512_load_pd(rows + 24);

            c00 = _mm512_fmadd_pd(_mm512_set1_pd(A[r0 + k]), b0, c00);
            c01 = _mm512_fmadd_pd(_mm512_set1_pd(A[r0 + k + 1]), b1, c01);
            c02 = _mm512_fmadd_pd(_mm512_set1_pd(A[r0 + k + 2]), b2, c02);
            c03 = _mm512_fmadd_pd(_mm512_set1_pd(A[r0 + k + 3]), b3, c03);

            c10 = _mm512_fmadd_pd(_mm512_set1_pd(A[r1 + k]), b0, c10);
            c11 = _mm512_fmadd_pd(_mm512_set1_pd(A[r1 + k + 1]), b1, c11);
            c12 = _mm512_fmadd_pd(_mm512_set1_pd(A[r1 + k + 2]), b2, c12);
            c13 = _mm512_fmadd_pd(_mm512_set1_pd(A[r1 + k + 3]), b3, c13);

            c20 = _mm512_fmadd_pd(_mm512_set1_pd(A[r2 + k]), b0, c20);
            c21 = _mm512_fmadd_pd(_mm512_set1_pd(A[r2 + k + 1]), b1, c21);
            c22 = _mm512_fmadd_pd(_mm512_set1_pd(A[r2 + k + 2]), b2, c22);
            c23 = _mm512_fmadd_pd(_mm512_set1_pd(A[r2 + k + 3]), b3, c23);

            c30 = _mm512_fmadd_pd(_mm512_set1_pd(A[r3 + k]), b0, c30);
            c31 = _mm512_fmadd_pd(_mm512_set1_pd(A[r3 + k + 1]), b1, c31);
            c32 = _mm512_fmadd_pd(_mm512_set1_pd(A[r3 + k + 2]), b2, c32);
            c33 = _mm512_fmadd_pd(_mm512_set1_pd(A[r3 + k + 3]), b3, c33);
          }

          /* k tail: these rows of B were never packed, so read them from B itself. */
          for (size_t k = n_vec; k < n; k++) {
            const __m512d b = _mm512_loadu_pd(&B[k * p + j]);
            c00 = _mm512_fmadd_pd(_mm512_set1_pd(A[r0 + k]), b, c00);
            c10 = _mm512_fmadd_pd(_mm512_set1_pd(A[r1 + k]), b, c10);
            c20 = _mm512_fmadd_pd(_mm512_set1_pd(A[r2 + k]), b, c20);
            c30 = _mm512_fmadd_pd(_mm512_set1_pd(A[r3 + k]), b, c30);
          }

          avx512_f64_reduce_store(&C[i * p + j], c00, c01, c02, c03, v_inv);
          avx512_f64_reduce_store(&C[(i + 1) * p + j], c10, c11, c12, c13, v_inv);
          avx512_f64_reduce_store(&C[(i + 2) * p + j], c20, c21, c22, c23, v_inv);
          avx512_f64_reduce_store(&C[(i + 3) * p + j], c30, c31, c32, c33, v_inv);
        }
      }

      /* Up to three leftover rows of this row block: 1 row x 8 columns per step. */
      for (size_t i = i_quad_end; i < i_end; i++) {
        const size_t r = i * n;

        for (size_t j = jj; j < j_end; j += 8) {
          const size_t panel = (j / 8) * n8 * 64;
          __m512d c0 = _mm512_setzero_pd(), c1 = c0, c2 = c0, c3 = c0;

          for (size_t k = 0; k < n_vec; k += 4) {
            const double *rows = &B_packed[panel + k * 8];
            c0 = _mm512_fmadd_pd(_mm512_set1_pd(A[r + k]), _mm512_load_pd(rows), c0);
            c1 = _mm512_fmadd_pd(_mm512_set1_pd(A[r + k + 1]), _mm512_load_pd(rows + 8), c1);
            c2 = _mm512_fmadd_pd(_mm512_set1_pd(A[r + k + 2]), _mm512_load_pd(rows + 16), c2);
            c3 = _mm512_fmadd_pd(_mm512_set1_pd(A[r + k + 3]), _mm512_load_pd(rows + 24), c3);
          }
          for (size_t k = n_vec; k < n; k++) {
            c0 = _mm512_fmadd_pd(_mm512_set1_pd(A[r + k]), _mm512_loadu_pd(&B[k * p + j]), c0);
          }

          avx512_f64_reduce_store(&C[i * p + j], c0, c1, c2, c3, v_inv);
        }
      }
    }
  }

  /* Columns that do not fill a whole 8-wide vector. */
  for (size_t i = 0; i < m; i++) {
    for (size_t j = p_vec; j < p; j++) {
      double sum = 0.0;
      for (size_t k = 0; k < n; k++) {
        sum += A[i * n + k] * B[k * p + j];
      }
      C[i * p + j] = sum * inv_scale;
    }
  }

  free(B_packed);
  return 0;
}
#endif
= _mm512_add_pd(acc0, acc2); 1353 1354 double tmp[8] __attribute__((aligned(64))); 1355 _mm512_store_pd(tmp, acc0); 1356 for (size_t dj = 0; dj < 8; dj++) acc[li * tj + lj + dj] += tmp[dj]; 1357 } 1358 for (size_t j = jj + (tj / 8) * 8; j < j_end; j++) { 1359 size_t lj = j - jj; 1360 for (size_t k = 0; k < n; k++) { 1361 acc[li * tj + lj] += A[rn + k] * B[k * p + j]; 1362 } 1363 } 1364 } 1365 1366 for (size_t i = ii; i < i_end; i++) { 1367 size_t li = i - ii; 1368 for (size_t j = jj; j < j_end; j++) { 1369 size_t lj = j - jj; 1370 C[i * p + j] = acc[li * tj + lj] * inv_scale; 1371 } 1372 } 1373 } 1374 } 1375 1376 for (size_t i = 0; i < m; i++) { 1377 for (size_t j = p8 * 8; j < p; j++) { 1378 double sum = 0.0; 1379 for (size_t k = 0; k < n; k++) { 1380 sum += A[i * n + k] * B[k * p + j]; 1381 } 1382 C[i * p + j] = sum * inv_scale; 1383 } 1384 } 1385 1386 free(B_packed); 1387 return 0; 1388 } 1389 #endif 1390 1391 static int _matmul_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, 1392 double scale) { 1393 static int initialized = 0; 1394 if (!initialized) { 1395 matmul_feature_t feat = matmul_get_feature(); 1396 #ifdef __AVX512F__ 1397 if (feat & MATMUL_FLAG_AVX512) 1398 matmul_f64_f64_f64 = matmul_avx512_f64_f64_f64; 1399 else 1400 #endif 1401 #ifdef __AVX2__ 1402 if (feat & MATMUL_FLAG_AVX2) 1403 matmul_f64_f64_f64 = matmul_avx2_f64_f64_f64; 1404 else 1405 #endif 1406 matmul_f64_f64_f64 = matmul_scalar_f64_f64_f64; 1407 initialized = 1; 1408 } 1409 return matmul_f64_f64_f64(m, n, p, A, B, C, scale); 1410 } 1411 1412 int (*matmul_f64_f64_f64)(size_t, size_t, size_t, const double *, const double *, double *, 1413 double) = _matmul_f64_f64_f64;