test_matmul.c (38885B)
1 #include <math.h> 2 #include <stdint.h> 3 #include <stdio.h> 4 #include <stdlib.h> 5 #include <string.h> 6 7 #include "../src/matmul.h" 8 #include "nemequ/munit.h" 9 #include "test_matmul_simd.h" 10 11 static void ref_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) { 12 for (size_t i = 0; i < m; i++) 13 for (size_t j = 0; j < p; j++) { 14 int sum = 0; 15 for (size_t k = 0; k < n; k++) sum += (int)A[i * n + k] * (int)B[k * p + j]; 16 if (scale > 1.0) sum = (int)(sum / scale); 17 if (sum > 255) sum = 255; 18 if (sum < 0) sum = 0; 19 C[i * p + j] = (uint8_t)sum; 20 } 21 } 22 23 static MunitResult test_u8_i8_u8_small(const char *name, 24 int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, 25 uint8_t *, double), 26 double epsilon) { 27 uint8_t A[] = {1, 2, 3, 4, 5, 6}; 28 int8_t B[] = {1, 0, 0, 1, 0, 0}; 29 uint8_t C[4], E[4]; 30 31 ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0); 32 matmul_fn(2, 3, 2, A, B, C, 0.0); 33 34 for (int i = 0; i < 4; i++) { 35 int d = E[i] - C[i]; 36 if (d < 0) d = -d; 37 if (d > (int)epsilon) return MUNIT_FAIL; 38 } 39 return MUNIT_OK; 40 } 41 42 static MunitResult test_u8_i8_u8_medium(const char *name, 43 int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, 44 uint8_t *, double), 45 double epsilon) { 46 const size_t m = 64, n = 64, p = 64; 47 uint8_t *A = malloc(m * n); 48 int8_t *B = malloc(n * p); 49 uint8_t *C = malloc(m * p); 50 uint8_t *E = malloc(m * p); 51 if (!A || !B || !C || !E) { 52 free(A); 53 free(B); 54 free(C); 55 free(E); 56 return MUNIT_SKIP; 57 } 58 59 // Deterministic pseudo-random values 60 for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)((i * 7 + 13) % 251); 61 for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(((i * 11 + 17) % 211) - 105); 62 memset(C, 0, m * p); 63 memset(E, 0, m * p); 64 65 ref_u8_i8_u8(m, n, p, A, B, E, 0.0); 66 matmul_fn(m, n, p, A, B, C, 0.0); 67 68 for (size_t i = 0; i < m * p; i++) { 69 int d = (int)E[i] - (int)C[i]; 70 if (d < 0) d = -d; 71 if (d > (int)epsilon) { 72 free(A); 73 free(B); 74 free(C); 75 free(E); 76 return MUNIT_FAIL; 77 } 78 } 79 80 free(A); 81 free(B); 82 free(C); 83 free(E); 84 return MUNIT_OK; 85 } 86 87 static MunitResult test_u8_i8_u8(const char *name, 88 int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *, 89 double), 90 double epsilon) { 91 return test_u8_i8_u8_small(name, matmul_fn, epsilon); 92 } 93 94 static MunitResult test_scalar_u8_i8_u8(const MunitParameter *params, void *data) { 95 (void)params; 96 (void)data; 97 return test_u8_i8_u8("scalar", matmul_scalar_u8_i8_u8, 0); 98 } 99 100 static MunitResult test_scalar_u8_i8_u8_medium(const MunitParameter *params, void *data) { 101 (void)params; 102 (void)data; 103 return test_u8_i8_u8_medium("scalar", matmul_scalar_u8_i8_u8, 0); 104 } 105 106 #ifdef __AVX512VNNI__ 107 static MunitResult test_avx512vnni_u8_i8_u8(const MunitParameter *params, void *data) { 108 (void)params; 109 (void)data; 110 return test_u8_i8_u8("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0); 111 } 112 113 static MunitResult test_avx512vnni_u8_i8_u8_medium(const MunitParameter *params, void *data) { 114 (void)params; 115 (void)data; 116 return test_u8_i8_u8_medium("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0); 117 } 118 #endif 119 120 static MunitResult test_dispatched_u8_i8_u8(const MunitParameter *params, void *data) { 121 (void)params; 122 (void)data; 123 return test_u8_i8_u8("dispatched", matmul_u8_i8_u8, 0); 124 } 125 126 static MunitResult test_dispatched_u8_i8_u8_medium(const MunitParameter *params, void *data) { 127 (void)params; 128 (void)data; 129 return test_u8_i8_u8_medium("dispatched", matmul_u8_i8_u8, 0); 130 } 131 132 static MunitResult test_u8_i8_u8_scaled_small(const char *name, 133 int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, 134 uint8_t *, double), 135 double epsilon) { 136 uint8_t A[] = {8, 16, 24, 32, 40, 48}; 137 int8_t B[] = {2, 0, 0, 2, 0, 0}; 138 uint8_t C[4], E[4]; 139 140 ref_u8_i8_u8(2, 3, 2, A, B, E, 4.0); 141 matmul_fn(2, 3, 2, A, B, C, 4.0); 142 143 for (int i = 0; i < 4; i++) { 144 int d = E[i] - C[i]; 145 if (d < 0) d = -d; 146 if (d > (int)epsilon) return MUNIT_FAIL; 147 } 148 return MUNIT_OK; 149 } 150 151 static MunitResult test_u8_i8_u8_scaled_medium(const char *name, 152 int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, 153 uint8_t *, double), 154 double epsilon) { 155 const size_t m = 64, n = 64, p = 64; 156 uint8_t *A = malloc(m * n); 157 int8_t *B = malloc(n * p); 158 uint8_t *C = malloc(m * p); 159 uint8_t *E = malloc(m * p); 160 if (!A || !B || !C || !E) { 161 free(A); 162 free(B); 163 free(C); 164 free(E); 165 return MUNIT_SKIP; 166 } 167 168 for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)((i * 7 + 13) % 251); 169 for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(((i * 11 + 17) % 211) - 105); 170 memset(C, 0, m * p); 171 memset(E, 0, m * p); 172 173 ref_u8_i8_u8(m, n, p, A, B, E, 8.0); 174 matmul_fn(m, n, p, A, B, C, 8.0); 175 176 for (size_t i = 0; i < m * p; i++) { 177 int d = (int)E[i] - (int)C[i]; 178 if (d < 0) d = -d; 179 if (d > (int)epsilon) { 180 free(A); 181 free(B); 182 free(C); 183 free(E); 184 return MUNIT_FAIL; 185 } 186 } 187 188 free(A); 189 free(B); 190 free(C); 191 free(E); 192 return MUNIT_OK; 193 } 194 195 static MunitResult test_scalar_u8_i8_u8_scaled(const MunitParameter *params, void *data) { 196 (void)params; 197 (void)data; 198 return test_u8_i8_u8_scaled_small("scalar", matmul_scalar_u8_i8_u8, 0); 199 } 200 201 static MunitResult test_scalar_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) { 202 (void)params; 203 (void)data; 204 return test_u8_i8_u8_scaled_medium("scalar", matmul_scalar_u8_i8_u8, 0); 205 } 206 207 #ifdef __AVX512VNNI__ 208 static MunitResult test_avx512vnni_u8_i8_u8_scaled(const MunitParameter *params, void *data) { 209 (void)params; 210 (void)data; 211 return test_u8_i8_u8_scaled_small("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0); 212 } 213 214 static MunitResult test_avx512vnni_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) { 215 (void)params; 216 (void)data; 217 return test_u8_i8_u8_scaled_medium("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0); 218 } 219 #endif 220 221 static MunitResult test_dispatched_u8_i8_u8_scaled(const MunitParameter *params, void *data) { 222 (void)params; 223 (void)data; 224 return test_u8_i8_u8_scaled_small("dispatched", matmul_u8_i8_u8, 0); 225 } 226 227 static MunitResult test_dispatched_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) { 228 (void)params; 229 (void)data; 230 return test_u8_i8_u8_scaled_medium("dispatched", matmul_u8_i8_u8, 0); 231 } 232 233 /* ========================================================================== */ 234 /* f32_f32_f32 tests */ 235 /* ========================================================================== */ 236 237 static void ref_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) { 238 for (size_t i = 0; i < m; i++) 239 for (size_t j = 0; j < p; j++) { 240 double sum = 0.0; 241 for (size_t k = 0; k < n; k++) sum += (double)A[i * n + k] * (double)B[k * p + j]; 242 if (scale > 1.0) sum /= scale; 243 C[i * p + j] = (float)sum; 244 } 245 } 246 247 static MunitResult test_f32_f32_f32_small(const char *name, 248 int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *, 249 float *, double), 250 double epsilon) { 251 float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; 252 float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}; 253 float C[4], E[4]; 254 255 ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0); 256 matmul_fn(2, 3, 2, A, B, C, 0.0); 257 258 for (int i = 0; i < 4; i++) { 259 double d = (double)E[i] - (double)C[i]; 260 if (d < 0) d = -d; 261 if (d > epsilon) return MUNIT_FAIL; 262 } 263 return MUNIT_OK; 264 } 265 266 static MunitResult test_f32_f32_f32_medium(const char *name, 267 int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *, 268 float *, double), 269 double epsilon) { 270 const size_t m = 64, n = 64, p = 64; 271 float *A = malloc(m * n * sizeof(float)); 272 float *B = malloc(n * p * sizeof(float)); 273 float *C = malloc(m * p * sizeof(float)); 274 float *E = malloc(m * p * sizeof(float)); 275 if (!A || !B || !C || !E) { 276 free(A); 277 free(B); 278 free(C); 279 free(E); 280 return MUNIT_SKIP; 281 } 282 283 for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251); 284 for (size_t i = 0; i < n * p; i++) B[i] = (float)(((i * 11 + 17) % 211) - 105); 285 memset(C, 0, m * p * sizeof(float)); 286 memset(E, 0, m * p * sizeof(float)); 287 288 ref_f32_f32_f32(m, n, p, A, B, E, 0.0); 289 matmul_fn(m, n, p, A, B, C, 0.0); 290 291 for (size_t i = 0; i < m * p; i++) { 292 double d = (double)E[i] - (double)C[i]; 293 if (d < 0) d = -d; 294 if (d > epsilon) { 295 free(A); 296 free(B); 297 free(C); 298 free(E); 299 return MUNIT_FAIL; 300 } 301 } 302 303 free(A); 304 free(B); 305 free(C); 306 free(E); 307 return MUNIT_OK; 308 } 309 310 static MunitResult test_f32_f32_f32_scaled_small(const char *name, 311 int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *, 312 float *, double), 313 double epsilon) { 314 float A[] = {8.0f, 16.0f, 24.0f, 32.0f, 40.0f, 48.0f}; 315 float B[] = {2.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f}; 316 float C[4], E[4]; 317 318 ref_f32_f32_f32(2, 3, 2, A, B, E, 4.0); 319 matmul_fn(2, 3, 2, A, B, C, 4.0); 320 321 for (int i = 0; i < 4; i++) { 322 double d = (double)E[i] - (double)C[i]; 323 if (d < 0) d = -d; 324 if (d > epsilon) return MUNIT_FAIL; 325 } 326 return MUNIT_OK; 327 } 328 329 static MunitResult test_f32_f32_f32_scaled_medium(const char *name, 330 int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *, 331 float *, double), 332 double epsilon) { 333 const size_t m = 64, n = 64, p = 64; 334 float *A = malloc(m * n * sizeof(float)); 335 float *B = malloc(n * p * sizeof(float)); 336 float *C = malloc(m * p * sizeof(float)); 337 float *E = malloc(m * p * sizeof(float)); 338 if (!A || !B || !C || !E) { 339 free(A); 340 free(B); 341 free(C); 342 free(E); 343 return MUNIT_SKIP; 344 } 345 346 for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251); 347 for (size_t i = 0; i < n * p; i++) B[i] = (float)(((i * 11 + 17) % 211) - 105); 348 memset(C, 0, m * p * sizeof(float)); 349 memset(E, 0, m * p * sizeof(float)); 350 351 ref_f32_f32_f32(m, n, p, A, B, E, 8.0); 352 matmul_fn(m, n, p, A, B, C, 8.0); 353 354 for (size_t i = 0; i < m * p; i++) { 355 double d = (double)E[i] - (double)C[i]; 356 if (d < 0) d = -d; 357 if (d > epsilon) { 358 free(A); 359 free(B); 360 free(C); 361 free(E); 362 return MUNIT_FAIL; 363 } 364 } 365 366 free(A); 367 free(B); 368 free(C); 369 free(E); 370 return MUNIT_OK; 371 } 372 373 static MunitResult test_scalar_f32_f32_f32(const MunitParameter *params, void *data) { 374 (void)params; 375 (void)data; 376 return test_f32_f32_f32_small("scalar", matmul_scalar_f32_f32_f32, 1e-5); 377 } 378 379 static MunitResult test_scalar_f32_f32_f32_medium(const MunitParameter *params, void *data) { 380 (void)params; 381 (void)data; 382 return test_f32_f32_f32_medium("scalar", matmul_scalar_f32_f32_f32, 1e-3); 383 } 384 385 static MunitResult test_scalar_f32_f32_f32_scaled(const MunitParameter *params, void *data) { 386 (void)params; 387 (void)data; 388 return test_f32_f32_f32_scaled_small("scalar", matmul_scalar_f32_f32_f32, 1e-5); 389 } 390 391 static MunitResult test_scalar_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) { 392 (void)params; 393 (void)data; 394 return test_f32_f32_f32_scaled_medium("scalar", matmul_scalar_f32_f32_f32, 1e-3); 395 } 396 397 #ifdef __AVX2__ 398 static MunitResult test_avx2_f32_f32_f32(const MunitParameter *params, void *data) { 399 (void)params; 400 (void)data; 401 return test_f32_f32_f32_small("avx2", matmul_avx2_f32_f32_f32, 1e-5); 402 } 403 404 static MunitResult test_avx2_f32_f32_f32_medium(const MunitParameter *params, void *data) { 405 (void)params; 406 (void)data; 407 return test_f32_f32_f32_medium("avx2", matmul_avx2_f32_f32_f32, 1e-3); 408 } 409 410 static MunitResult test_avx2_f32_f32_f32_scaled(const MunitParameter *params, void *data) { 411 (void)params; 412 (void)data; 413 return test_f32_f32_f32_scaled_small("avx2", matmul_avx2_f32_f32_f32, 1e-5); 414 } 415 416 static MunitResult test_avx2_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) { 417 (void)params; 418 (void)data; 419 return test_f32_f32_f32_scaled_medium("avx2", matmul_avx2_f32_f32_f32, 1e-3); 420 } 421 #endif 422 423 #ifdef __AVX512F__ 424 static MunitResult test_avx512_f32_f32_f32(const MunitParameter *params, void *data) { 425 (void)params; 426 (void)data; 427 return test_f32_f32_f32_small("avx512", matmul_avx512_f32_f32_f32, 1e-5); 428 } 429 430 static MunitResult test_avx512_f32_f32_f32_medium(const MunitParameter *params, void *data) { 431 (void)params; 432 (void)data; 433 return test_f32_f32_f32_medium("avx512", matmul_avx512_f32_f32_f32, 1e-3); 434 } 435 436 static MunitResult test_avx512_f32_f32_f32_scaled(const MunitParameter *params, void *data) { 437 (void)params; 438 (void)data; 439 return test_f32_f32_f32_scaled_small("avx512", matmul_avx512_f32_f32_f32, 1e-5); 440 } 441 442 static MunitResult test_avx512_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) { 443 (void)params; 444 (void)data; 445 return test_f32_f32_f32_scaled_medium("avx512", matmul_avx512_f32_f32_f32, 1e-3); 446 } 447 #endif 448 449 static MunitResult test_dispatched_f32_f32_f32(const MunitParameter *params, void *data) { 450 (void)params; 451 (void)data; 452 return test_f32_f32_f32_small("dispatched", matmul_f32_f32_f32, 1e-5); 453 } 454 455 static MunitResult test_dispatched_f32_f32_f32_medium(const MunitParameter *params, void *data) { 456 (void)params; 457 (void)data; 458 return test_f32_f32_f32_medium("dispatched", matmul_f32_f32_f32, 1e-3); 459 } 460 461 static MunitResult test_dispatched_f32_f32_f32_scaled(const MunitParameter *params, void *data) { 462 (void)params; 463 (void)data; 464 return test_f32_f32_f32_scaled_small("dispatched", matmul_f32_f32_f32, 1e-5); 465 } 466 467 static MunitResult test_dispatched_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) { 468 (void)params; 469 (void)data; 470 return test_f32_f32_f32_scaled_medium("dispatched", matmul_f32_f32_f32, 1e-3); 471 } 472 473 /* ========================================================================== */ 474 /* f64_f64_f64 tests */ 475 /* ========================================================================== */ 476 477 static void ref_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) { 478 for (size_t i = 0; i < m; i++) 479 for (size_t j = 0; j < p; j++) { 480 double sum = 0.0; 481 for (size_t k = 0; k < n; k++) sum += A[i * n + k] * B[k * p + j]; 482 if (scale > 1.0) sum /= scale; 483 C[i * p + j] = sum; 484 } 485 } 486 487 static MunitResult test_f64_f64_f64_small(const char *name, 488 int (*matmul_fn)(size_t, size_t, size_t, const double *, const double *, 489 double *, double), 490 double epsilon) { 491 double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; 492 double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0}; 493 double C[4], E[4]; 494 495 ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0); 496 matmul_fn(2, 3, 2, A, B, C, 0.0); 497 498 for (int i = 0; i < 4; i++) { 499 double d = E[i] - C[i]; 500 if (d < 0) d = -d; 501 if (d > epsilon) return MUNIT_FAIL; 502 } 503 return MUNIT_OK; 504 } 505 506 static MunitResult test_f64_f64_f64_medium(const char *name, 507 int (*matmul_fn)(size_t, size_t, size_t, const double *, const double *, 508 double *, double), 509 double epsilon) { 510 const size_t m = 64, n = 64, p = 64; 511 double *A = malloc(m * n * sizeof(double)); 512 double *B = malloc(n * p * sizeof(double)); 513 double *C = malloc(m * p * sizeof(double)); 514 double *E = malloc(m * p * sizeof(double)); 515 if (!A || !B || !C || !E) { 516 free(A); 517 free(B); 518 free(C); 519 free(E); 520 return MUNIT_SKIP; 521 } 522 523 for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251); 524 for (size_t i = 0; i < n * p; i++) B[i] = (double)(((i * 11 + 17) % 211) - 105); 525 memset(C, 0, m * p * sizeof(double)); 526 memset(E, 0, m * p * sizeof(double)); 527 528 ref_f64_f64_f64(m, n, p, A, B, E, 0.0); 529 matmul_fn(m, n, p, A, B, C, 0.0); 530 531 for (size_t i = 0; i < m * p; i++) { 532 double d = E[i] - C[i]; 533 if (d < 0) d = -d; 534 if (d > epsilon) { 535 free(A); 536 free(B); 537 free(C); 538 free(E); 539 return MUNIT_FAIL; 540 } 541 } 542 543 free(A); 544 free(B); 545 free(C); 546 free(E); 547 return MUNIT_OK; 548 } 549 550 static MunitResult test_f64_f64_f64_scaled_small(const char *name, 551 int (*matmul_fn)(size_t, size_t, size_t, const double *, 552 const double *, double *, double), 553 double epsilon) { 554 double A[] = {8.0, 16.0, 24.0, 32.0, 40.0, 48.0}; 555 double B[] = {2.0, 0.0, 0.0, 2.0, 0.0, 0.0}; 556 double C[4], E[4]; 557 558 ref_f64_f64_f64(2, 3, 2, A, B, E, 4.0); 559 matmul_fn(2, 3, 2, A, B, C, 4.0); 560 561 for (int i = 0; i < 4; i++) { 562 double d = E[i] - C[i]; 563 if (d < 0) d = -d; 564 if (d > epsilon) return MUNIT_FAIL; 565 } 566 return MUNIT_OK; 567 } 568 569 static MunitResult test_f64_f64_f64_scaled_medium(const char *name, 570 int (*matmul_fn)(size_t, size_t, size_t, const double *, 571 const double *, double *, double), 572 double epsilon) { 573 const size_t m = 64, n = 64, p = 64; 574 double *A = malloc(m * n * sizeof(double)); 575 double *B = malloc(n * p * sizeof(double)); 576 double *C = malloc(m * p * sizeof(double)); 577 double *E = malloc(m * p * sizeof(double)); 578 if (!A || !B || !C || !E) { 579 free(A); 580 free(B); 581 free(C); 582 free(E); 583 return MUNIT_SKIP; 584 } 585 586 for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251); 587 for (size_t i = 0; i < n * p; i++) B[i] = (double)(((i * 11 + 17) % 211) - 105); 588 memset(C, 0, m * p * sizeof(double)); 589 memset(E, 0, m * p * sizeof(double)); 590 591 ref_f64_f64_f64(m, n, p, A, B, E, 8.0); 592 matmul_fn(m, n, p, A, B, C, 8.0); 593 594 for (size_t i = 0; i < m * p; i++) { 595 double d = E[i] - C[i]; 596 if (d < 0) d = -d; 597 if (d > epsilon) { 598 free(A); 599 free(B); 600 free(C); 601 free(E); 602 return MUNIT_FAIL; 603 } 604 } 605 606 free(A); 607 free(B); 608 free(C); 609 free(E); 610 return MUNIT_OK; 611 } 612 613 static MunitResult test_scalar_f64_f64_f64(const MunitParameter *params, void *data) { 614 (void)params; 615 (void)data; 616 return test_f64_f64_f64_small("scalar", matmul_scalar_f64_f64_f64, 1e-12); 617 } 618 619 static MunitResult test_scalar_f64_f64_f64_medium(const MunitParameter *params, void *data) { 620 (void)params; 621 (void)data; 622 return test_f64_f64_f64_medium("scalar", matmul_scalar_f64_f64_f64, 1e-9); 623 } 624 625 static MunitResult test_scalar_f64_f64_f64_scaled(const MunitParameter *params, void *data) { 626 (void)params; 627 (void)data; 628 return test_f64_f64_f64_scaled_small("scalar", matmul_scalar_f64_f64_f64, 1e-12); 629 } 630 631 static MunitResult test_scalar_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) { 632 (void)params; 633 (void)data; 634 return test_f64_f64_f64_scaled_medium("scalar", matmul_scalar_f64_f64_f64, 1e-9); 635 } 636 637 #ifdef __AVX2__ 638 static MunitResult test_avx2_f64_f64_f64(const MunitParameter *params, void *data) { 639 (void)params; 640 (void)data; 641 return test_f64_f64_f64_small("avx2", matmul_avx2_f64_f64_f64, 1e-12); 642 } 643 644 static MunitResult test_avx2_f64_f64_f64_medium(const MunitParameter *params, void *data) { 645 (void)params; 646 (void)data; 647 return test_f64_f64_f64_medium("avx2", matmul_avx2_f64_f64_f64, 1e-9); 648 } 649 650 static MunitResult test_avx2_f64_f64_f64_scaled(const MunitParameter *params, void *data) { 651 (void)params; 652 (void)data; 653 return test_f64_f64_f64_scaled_small("avx2", matmul_avx2_f64_f64_f64, 1e-12); 654 } 655 656 static MunitResult test_avx2_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) { 657 (void)params; 658 (void)data; 659 return test_f64_f64_f64_scaled_medium("avx2", matmul_avx2_f64_f64_f64, 1e-9); 660 } 661 #endif 662 663 #ifdef __AVX512F__ 664 static MunitResult test_avx512_f64_f64_f64(const MunitParameter *params, void *data) { 665 (void)params; 666 (void)data; 667 return test_f64_f64_f64_small("avx512", matmul_avx512_f64_f64_f64, 1e-12); 668 } 669 670 static MunitResult test_avx512_f64_f64_f64_medium(const MunitParameter *params, void *data) { 671 (void)params; 672 (void)data; 673 return test_f64_f64_f64_medium("avx512", matmul_avx512_f64_f64_f64, 1e-9); 674 } 675 676 static MunitResult test_avx512_f64_f64_f64_scaled(const MunitParameter *params, void *data) { 677 (void)params; 678 (void)data; 679 return test_f64_f64_f64_scaled_small("avx512", matmul_avx512_f64_f64_f64, 1e-12); 680 } 681 682 static MunitResult test_avx512_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) { 683 (void)params; 684 (void)data; 685 return test_f64_f64_f64_scaled_medium("avx512", matmul_avx512_f64_f64_f64, 1e-9); 686 } 687 #endif 688 689 static MunitResult test_dispatched_f64_f64_f64(const MunitParameter *params, void *data) { 690 (void)params; 691 (void)data; 692 return test_f64_f64_f64_small("dispatched", matmul_f64_f64_f64, 1e-12); 693 } 694 695 static MunitResult test_dispatched_f64_f64_f64_medium(const MunitParameter *params, void *data) { 696 (void)params; 697 (void)data; 698 return test_f64_f64_f64_medium("dispatched", matmul_f64_f64_f64, 1e-9); 699 } 700 701 static MunitResult test_dispatched_f64_f64_f64_scaled(const MunitParameter *params, void *data) { 702 (void)params; 703 (void)data; 704 return test_f64_f64_f64_scaled_small("dispatched", matmul_f64_f64_f64, 1e-12); 705 } 706 707 static MunitResult test_dispatched_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) { 708 (void)params; 709 (void)data; 710 return test_f64_f64_f64_scaled_medium("dispatched", matmul_f64_f64_f64, 1e-9); 711 } 712 713 /* ========================================================================== */ 714 /* matmul generic macro tests */ 715 /* ========================================================================== */ 716 717 static MunitResult test_matmul_u8_i8_u8(const char *name, const char *type) { 718 const uint8_t A[] = {1, 2, 3, 4, 5, 6}; 719 const int8_t B[] = {1, 0, 0, 1, 0, 0}; 720 uint8_t C[4], E[4]; 721 722 ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0); 723 matmul(2, 3, 2, A, B, C, 0.0); 724 725 for (int i = 0; i < 4; i++) { 726 int d = E[i] - C[i]; 727 if (d < 0) d = -d; 728 if (d > 0) return MUNIT_FAIL; 729 } 730 return MUNIT_OK; 731 } 732 733 static MunitResult test_matmul_u8_i8_u8_scaled(const char *name, const char *type) { 734 const uint8_t A[] = {8, 16, 24, 32, 40, 48}; 735 const int8_t B[] = {2, 0, 0, 2, 0, 0}; 736 uint8_t C[4], E[4]; 737 738 ref_u8_i8_u8(2, 3, 2, A, B, E, 4.0); 739 matmul(2, 3, 2, A, B, C, 4.0); 740 741 for (int i = 0; i < 4; i++) { 742 int d = E[i] - C[i]; 743 if (d < 0) d = -d; 744 if (d > 0) return MUNIT_FAIL; 745 } 746 return MUNIT_OK; 747 } 748 749 static MunitResult test_matmul_u8_i8_u8_medium(const char *name, const char *type) { 750 const size_t m = 64, n = 64, p = 64; 751 uint8_t *A = malloc(m * n); 752 int8_t *B = malloc(n * p); 753 uint8_t *C = malloc(m * p); 754 uint8_t *E = malloc(m * p); 755 if (!A || !B || !C || !E) { 756 free(A); 757 free(B); 758 free(C); 759 free(E); 760 return MUNIT_SKIP; 761 } 762 763 for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)((i * 7 + 13) % 251); 764 for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(((i * 11 + 17) % 211) - 105); 765 memset(C, 0, m * p); 766 memset(E, 0, m * p); 767 768 ref_u8_i8_u8(m, n, p, A, B, E, 0.0); 769 matmul(m, n, p, A, B, C, 0.0); 770 771 for (size_t i = 0; i < m * p; i++) { 772 int d = (int)E[i] - (int)C[i]; 773 if (d < 0) d = -d; 774 if (d > 0) { 775 free(A); 776 free(B); 777 free(C); 778 free(E); 779 return MUNIT_FAIL; 780 } 781 } 782 783 free(A); 784 free(B); 785 free(C); 786 free(E); 787 return MUNIT_OK; 788 } 789 790 static MunitResult test_matmul_f32_f32_f32(const char *name, const char *type) { 791 const float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; 792 const float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}; 793 float C[4], E[4]; 794 795 ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0); 796 matmul(2, 3, 2, A, B, C, 0.0); 797 798 for (int i = 0; i < 4; i++) { 799 double d = (double)E[i] - (double)C[i]; 800 if (d < 0) d = -d; 801 if (d > 1e-5) return MUNIT_FAIL; 802 } 803 return MUNIT_OK; 804 } 805 806 static MunitResult test_matmul_f32_f32_f32_scaled(const char *name, const char *type) { 807 const float A[] = {8.0f, 16.0f, 24.0f, 32.0f, 40.0f, 48.0f}; 808 const float B[] = {2.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f}; 809 float C[4], E[4]; 810 811 ref_f32_f32_f32(2, 3, 2, A, B, E, 4.0); 812 matmul(2, 3, 2, A, B, C, 4.0); 813 814 for (int i = 0; i < 4; i++) { 815 double d = (double)E[i] - (double)C[i]; 816 if (d < 0) d = -d; 817 if (d > 1e-5) return MUNIT_FAIL; 818 } 819 return MUNIT_OK; 820 } 821 822 static MunitResult test_matmul_f32_f32_f32_medium(const char *name, const char *type) { 823 const size_t m = 64, n = 64, p = 64; 824 float *A = malloc(m * n * sizeof(float)); 825 float *B = malloc(n * p * sizeof(float)); 826 float *C = malloc(m * p * sizeof(float)); 827 float *E = malloc(m * p * sizeof(float)); 828 if (!A || !B || !C || !E) { 829 free(A); 830 free(B); 831 free(C); 832 free(E); 833 return MUNIT_SKIP; 834 } 835 836 for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251); 837 for (size_t i = 0; i < n * p; i++) B[i] = (float)(((i * 11 + 17) % 211) - 105); 838 memset(C, 0, m * p * sizeof(float)); 839 memset(E, 0, m * p * sizeof(float)); 840 841 ref_f32_f32_f32(m, n, p, A, B, E, 0.0); 842 matmul(m, n, p, A, B, C, 0.0); 843 844 for (size_t i = 0; i < m * p; i++) { 845 double d = (double)E[i] - (double)C[i]; 846 if (d < 0) d = -d; 847 if (d > 1e-3) { 848 free(A); 849 free(B); 850 free(C); 851 free(E); 852 return MUNIT_FAIL; 853 } 854 } 855 856 free(A); 857 free(B); 858 free(C); 859 free(E); 860 return MUNIT_OK; 861 } 862 863 static MunitResult test_matmul_f64_f64_f64(const char *name, const char *type) { 864 const double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; 865 const double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0}; 866 double C[4], E[4]; 867 868 ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0); 869 matmul(2, 3, 2, A, B, C, 0.0); 870 871 for (int i = 0; i < 4; i++) { 872 double d = E[i] - C[i]; 873 if (d < 0) d = -d; 874 if (d > 1e-12) return MUNIT_FAIL; 875 } 876 return MUNIT_OK; 877 } 878 879 static MunitResult test_matmul_f64_f64_f64_scaled(const char *name, const char *type) { 880 const double A[] = {8.0, 16.0, 24.0, 32.0, 40.0, 48.0}; 881 const double B[] = {2.0, 0.0, 0.0, 2.0, 0.0, 0.0}; 882 double C[4], E[4]; 883 884 ref_f64_f64_f64(2, 3, 2, A, B, E, 4.0); 885 matmul(2, 3, 2, A, B, C, 4.0); 886 887 for (int i = 0; i < 4; i++) { 888 double d = E[i] - C[i]; 889 if (d < 0) d = -d; 890 if (d > 1e-12) return MUNIT_FAIL; 891 } 892 return MUNIT_OK; 893 } 894 895 static MunitResult test_matmul_f64_f64_f64_medium(const char *name, const char *type) { 896 const size_t m = 64, n = 64, p = 64; 897 double *A = malloc(m * n * sizeof(double)); 898 double *B = malloc(n * p * sizeof(double)); 899 double *C = malloc(m * p * sizeof(double)); 900 double *E = malloc(m * p * sizeof(double)); 901 if (!A || !B || !C || !E) { 902 free(A); 903 free(B); 904 free(C); 905 free(E); 906 return MUNIT_SKIP; 907 } 908 909 for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251); 910 for (size_t i = 0; i < n * p; i++) B[i] = (double)(((i * 11 + 17) % 211) - 105); 911 memset(C, 0, m * p * sizeof(double)); 912 memset(E, 0, m * p * sizeof(double)); 913 914 ref_f64_f64_f64(m, n, p, A, B, E, 0.0); 915 matmul(m, n, p, A, B, C, 0.0); 916 917 for (size_t i = 0; i < m * p; i++) { 918 double d = E[i] - C[i]; 919 if (d < 0) d = -d; 920 if (d > 1e-9) { 921 free(A); 922 free(B); 923 free(C); 924 free(E); 925 return MUNIT_FAIL; 926 } 927 } 928 929 free(A); 930 free(B); 931 free(C); 932 free(E); 933 return MUNIT_OK; 934 } 935 936 static MunitResult test_matmul_u8_i8_u8_const(const MunitParameter *params, void *data) { 937 (void)params; 938 (void)data; 939 return test_matmul_u8_i8_u8("const", "u8_i8_u8"); 940 } 941 942 static MunitResult test_matmul_u8_i8_u8_nonconst(const MunitParameter *params, void *data) { 943 (void)params; 944 (void)data; 945 uint8_t A[] = {1, 2, 3, 4, 5, 6}; 946 int8_t B[] = {1, 0, 0, 1, 0, 0}; 947 uint8_t C[4], E[4]; 948 949 ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0); 950 matmul(2, 3, 2, A, B, C, 0.0); 951 952 for (int i = 0; i < 4; i++) { 953 int d = E[i] - C[i]; 954 if (d < 0) d = -d; 955 if (d > 0) return MUNIT_FAIL; 956 } 957 return MUNIT_OK; 958 } 959 960 static MunitResult test_matmul_u8_i8_u8_scaled_const(const MunitParameter *params, void *data) { 961 (void)params; 962 (void)data; 963 return test_matmul_u8_i8_u8_scaled("const", "u8_i8_u8"); 964 } 965 966 static MunitResult test_matmul_u8_i8_u8_medium_const(const MunitParameter *params, void *data) { 967 (void)params; 968 (void)data; 969 return test_matmul_u8_i8_u8_medium("const", "u8_i8_u8"); 970 } 971 972 static MunitResult test_matmul_f32_f32_f32_const(const MunitParameter *params, void *data) { 973 (void)params; 974 (void)data; 975 return test_matmul_f32_f32_f32("const", "f32_f32_f32"); 976 } 977 978 static int matmul_f32_f32_f32_nonconst(size_t m, size_t n, size_t p, float *A, float *B, float *C, double scale) { 979 return matmul(m, n, p, A, B, C, scale); 980 } 981 982 static MunitResult test_matmul_f32_f32_f32_nonconst(const MunitParameter *params, void *data) { 983 (void)params; 984 (void)data; 985 float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; 986 float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}; 987 float C[4], E[4]; 988 989 ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0); 990 matmul_f32_f32_f32_nonconst(2, 3, 2, A, B, C, 0.0); 991 992 for (int i = 0; i < 4; i++) { 993 double d = (double)E[i] - (double)C[i]; 994 if (d < 0) d = -d; 995 if (d > 1e-5) return MUNIT_FAIL; 996 } 997 return MUNIT_OK; 998 } 999 1000 static MunitResult test_matmul_f32_f32_f32_scaled_const(const MunitParameter *params, void *data) { 1001 (void)params; 1002 (void)data; 1003 return test_matmul_f32_f32_f32_scaled("const", "f32_f32_f32"); 1004 } 1005 1006 static MunitResult test_matmul_f32_f32_f32_medium_const(const MunitParameter *params, void *data) { 1007 (void)params; 1008 (void)data; 1009 return test_matmul_f32_f32_f32_medium("const", "f32_f32_f32"); 1010 } 1011 1012 static MunitResult test_matmul_f64_f64_f64_const(const MunitParameter *params, void *data) { 1013 (void)params; 1014 (void)data; 1015 return test_matmul_f64_f64_f64("const", "f64_f64_f64"); 1016 } 1017 1018 static int matmul_f64_f64_f64_nonconst(size_t m, size_t n, size_t p, double *A, double *B, double *C, double scale) { 1019 return matmul(m, n, p, A, B, C, scale); 1020 } 1021 1022 static MunitResult test_matmul_f64_f64_f64_nonconst(const MunitParameter *params, void *data) { 1023 (void)params; 1024 (void)data; 1025 double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; 1026 double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0}; 1027 double C[4], E[4]; 1028 1029 ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0); 1030 matmul_f64_f64_f64_nonconst(2, 3, 2, A, B, C, 0.0); 1031 1032 for (int i = 0; i < 4; i++) { 1033 double d = E[i] - C[i]; 1034 if (d < 0) d = -d; 1035 if (d > 1e-12) return MUNIT_FAIL; 1036 } 1037 return MUNIT_OK; 1038 } 1039 1040 static MunitResult test_matmul_f64_f64_f64_scaled_const(const MunitParameter *params, void *data) { 1041 (void)params; 1042 (void)data; 1043 return test_matmul_f64_f64_f64_scaled("const", "f64_f64_f64"); 1044 } 1045 1046 static MunitResult test_matmul_f64_f64_f64_medium_const(const MunitParameter *params, void *data) { 1047 (void)params; 1048 (void)data; 1049 return test_matmul_f64_f64_f64_medium("const", "f64_f64_f64"); 1050 } 1051 1052 static MunitTest tests[] = { 1053 {"/scalar-u8-i8-u8", test_scalar_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1054 {"/scalar-u8-i8-u8-medium", test_scalar_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1055 {"/scalar-u8-i8-u8-scaled", test_scalar_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1056 {"/scalar-u8-i8-u8-scaled-medium", test_scalar_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1057 #ifdef __AVX512VNNI__ 1058 {"/avx512vnni-u8-i8-u8", test_avx512vnni_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1059 {"/avx512vnni-u8-i8-u8-medium", test_avx512vnni_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1060 {"/avx512vnni-u8-i8-u8-scaled", test_avx512vnni_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1061 {"/avx512vnni-u8-i8-u8-scaled-medium", test_avx512vnni_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, 1062 NULL}, 1063 #endif 1064 {"/dispatched-u8-i8-u8", test_dispatched_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1065 {"/dispatched-u8-i8-u8-medium", test_dispatched_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1066 {"/dispatched-u8-i8-u8-scaled", test_dispatched_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1067 {"/dispatched-u8-i8-u8-scaled-medium", test_dispatched_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, 1068 NULL}, 1069 {"/scalar-f32-f32-f32", test_scalar_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1070 {"/scalar-f32-f32-f32-medium", test_scalar_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1071 {"/scalar-f32-f32-f32-scaled", test_scalar_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1072 {"/scalar-f32-f32-f32-scaled-medium", test_scalar_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, 1073 NULL}, 1074 #ifdef __AVX2__ 1075 {"/avx2-f32-f32-f32", test_avx2_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1076 {"/avx2-f32-f32-f32-medium", test_avx2_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1077 {"/avx2-f32-f32-f32-scaled", test_avx2_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1078 {"/avx2-f32-f32-f32-scaled-medium", test_avx2_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1079 #endif 1080 #ifdef __AVX512F__ 1081 {"/avx512-f32-f32-f32", test_avx512_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1082 {"/avx512-f32-f32-f32-medium", test_avx512_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1083 {"/avx512-f32-f32-f32-scaled", test_avx512_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1084 {"/avx512-f32-f32-f32-scaled-medium", test_avx512_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, 1085 NULL}, 1086 #endif 1087 {"/dispatched-f32-f32-f32", test_dispatched_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1088 {"/dispatched-f32-f32-f32-medium", test_dispatched_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1089 {"/dispatched-f32-f32-f32-scaled", test_dispatched_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1090 {"/dispatched-f32-f32-f32-scaled-medium", test_dispatched_f32_f32_f32_scaled_medium, NULL, NULL, 1091 MUNIT_TEST_OPTION_NONE, NULL}, 1092 {"/scalar-f64-f64-f64", test_scalar_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1093 {"/scalar-f64-f64-f64-medium", test_scalar_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1094 {"/scalar-f64-f64-f64-scaled", test_scalar_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1095 {"/scalar-f64-f64-f64-scaled-medium", test_scalar_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, 1096 NULL}, 1097 #ifdef __AVX2__ 1098 {"/avx2-f64-f64-f64", test_avx2_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1099 {"/avx2-f64-f64-f64-medium", test_avx2_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1100 {"/avx2-f64-f64-f64-scaled", test_avx2_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1101 {"/avx2-f64-f64-f64-scaled-medium", test_avx2_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1102 #endif 1103 #ifdef __AVX512F__ 1104 {"/avx512-f64-f64-f64", test_avx512_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1105 {"/avx512-f64-f64-f64-medium", test_avx512_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1106 {"/avx512-f64-f64-f64-scaled", test_avx512_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1107 {"/avx512-f64-f64-f64-scaled-medium", test_avx512_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, 1108 NULL}, 1109 #endif 1110 {"/dispatched-f64-f64-f64", test_dispatched_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1111 {"/dispatched-f64-f64-f64-medium", test_dispatched_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1112 {"/dispatched-f64-f64-f64-scaled", test_dispatched_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1113 {"/dispatched-f64-f64-f64-scaled-medium", test_dispatched_f64_f64_f64_scaled_medium, NULL, NULL, 1114 MUNIT_TEST_OPTION_NONE, NULL}, 1115 {"/matmul-u8-i8-u8-const", test_matmul_u8_i8_u8_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1116 {"/matmul-u8-i8-u8-nonconst", test_matmul_u8_i8_u8_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1117 {"/matmul-u8-i8-u8-scaled-const", test_matmul_u8_i8_u8_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1118 {"/matmul-u8-i8-u8-medium-const", test_matmul_u8_i8_u8_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1119 {"/matmul-f32-f32-f32-const", test_matmul_f32_f32_f32_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1120 {"/matmul-f32-f32-f32-nonconst", test_matmul_f32_f32_f32_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1121 {"/matmul-f32-f32-f32-scaled-const", test_matmul_f32_f32_f32_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, 1122 NULL}, 1123 {"/matmul-f32-f32-f32-medium-const", test_matmul_f32_f32_f32_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, 1124 NULL}, 1125 {"/matmul-f64-f64-f64-const", test_matmul_f64_f64_f64_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1126 {"/matmul-f64-f64-f64-nonconst", test_matmul_f64_f64_f64_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, 1127 {"/matmul-f64-f64-f64-scaled-const", test_matmul_f64_f64_f64_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, 1128 NULL}, 1129 {"/matmul-f64-f64-f64-medium-const", test_matmul_f64_f64_f64_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, 1130 NULL}, 1131 {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}}; 1132 1133 static const MunitSuite suite = {"/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE}; 1134 1135 int main(int argc, char *argv[MUNIT_ARRAY_PARAM(argc)]) { 1136 return munit_suite_main(&suite, NULL, argc, argv); 1137 }