matmul.c

Matrix multiplication helper library
git clone git://git.finwo.net/lib/matmul.c
Log | Files | Refs | README | LICENSE

test_matmul.c (38885B)


      1 #include <math.h>
      2 #include <stdint.h>
      3 #include <stdio.h>
      4 #include <stdlib.h>
      5 #include <string.h>
      6 
      7 #include "../src/matmul.h"
      8 #include "nemequ/munit.h"
      9 #include "test_matmul_simd.h"
     10 
     11 static void ref_u8_i8_u8(size_t m, size_t n, size_t p, const uint8_t *A, const int8_t *B, uint8_t *C, double scale) {
     12   for (size_t i = 0; i < m; i++)
     13     for (size_t j = 0; j < p; j++) {
     14       int sum = 0;
     15       for (size_t k = 0; k < n; k++) sum += (int)A[i * n + k] * (int)B[k * p + j];
     16       if (scale > 1.0) sum = (int)(sum / scale);
     17       if (sum > 255) sum = 255;
     18       if (sum < 0) sum = 0;
     19       C[i * p + j] = (uint8_t)sum;
     20     }
     21 }
     22 
     23 static MunitResult test_u8_i8_u8_small(const char *name,
     24                                        int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *,
     25                                                         uint8_t *, double),
     26                                        double epsilon) {
     27   uint8_t A[] = {1, 2, 3, 4, 5, 6};
     28   int8_t  B[] = {1, 0, 0, 1, 0, 0};
     29   uint8_t C[4], E[4];
     30 
     31   ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0);
     32   matmul_fn(2, 3, 2, A, B, C, 0.0);
     33 
     34   for (int i = 0; i < 4; i++) {
     35     int d = E[i] - C[i];
     36     if (d < 0) d = -d;
     37     if (d > (int)epsilon) return MUNIT_FAIL;
     38   }
     39   return MUNIT_OK;
     40 }
     41 
     42 static MunitResult test_u8_i8_u8_medium(const char *name,
     43                                         int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *,
     44                                                          uint8_t *, double),
     45                                         double epsilon) {
     46   const size_t m = 64, n = 64, p = 64;
     47   uint8_t     *A = malloc(m * n);
     48   int8_t      *B = malloc(n * p);
     49   uint8_t     *C = malloc(m * p);
     50   uint8_t     *E = malloc(m * p);
     51   if (!A || !B || !C || !E) {
     52     free(A);
     53     free(B);
     54     free(C);
     55     free(E);
     56     return MUNIT_SKIP;
     57   }
     58 
     59   // Deterministic pseudo-random values
     60   for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)((i * 7 + 13) % 251);
     61   for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(((i * 11 + 17) % 211) - 105);
     62   memset(C, 0, m * p);
     63   memset(E, 0, m * p);
     64 
     65   ref_u8_i8_u8(m, n, p, A, B, E, 0.0);
     66   matmul_fn(m, n, p, A, B, C, 0.0);
     67 
     68   for (size_t i = 0; i < m * p; i++) {
     69     int d = (int)E[i] - (int)C[i];
     70     if (d < 0) d = -d;
     71     if (d > (int)epsilon) {
     72       free(A);
     73       free(B);
     74       free(C);
     75       free(E);
     76       return MUNIT_FAIL;
     77     }
     78   }
     79 
     80   free(A);
     81   free(B);
     82   free(C);
     83   free(E);
     84   return MUNIT_OK;
     85 }
     86 
     87 static MunitResult test_u8_i8_u8(const char *name,
     88                                  int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *, uint8_t *,
     89                                                   double),
     90                                  double epsilon) {
     91   return test_u8_i8_u8_small(name, matmul_fn, epsilon);
     92 }
     93 
     94 static MunitResult test_scalar_u8_i8_u8(const MunitParameter *params, void *data) {
     95   (void)params;
     96   (void)data;
     97   return test_u8_i8_u8("scalar", matmul_scalar_u8_i8_u8, 0);
     98 }
     99 
    100 static MunitResult test_scalar_u8_i8_u8_medium(const MunitParameter *params, void *data) {
    101   (void)params;
    102   (void)data;
    103   return test_u8_i8_u8_medium("scalar", matmul_scalar_u8_i8_u8, 0);
    104 }
    105 
    106 #ifdef __AVX512VNNI__
    107 static MunitResult test_avx512vnni_u8_i8_u8(const MunitParameter *params, void *data) {
    108   (void)params;
    109   (void)data;
    110   return test_u8_i8_u8("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0);
    111 }
    112 
    113 static MunitResult test_avx512vnni_u8_i8_u8_medium(const MunitParameter *params, void *data) {
    114   (void)params;
    115   (void)data;
    116   return test_u8_i8_u8_medium("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0);
    117 }
    118 #endif
    119 
    120 static MunitResult test_dispatched_u8_i8_u8(const MunitParameter *params, void *data) {
    121   (void)params;
    122   (void)data;
    123   return test_u8_i8_u8("dispatched", matmul_u8_i8_u8, 0);
    124 }
    125 
    126 static MunitResult test_dispatched_u8_i8_u8_medium(const MunitParameter *params, void *data) {
    127   (void)params;
    128   (void)data;
    129   return test_u8_i8_u8_medium("dispatched", matmul_u8_i8_u8, 0);
    130 }
    131 
    132 static MunitResult test_u8_i8_u8_scaled_small(const char *name,
    133                                               int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *,
    134                                                                uint8_t *, double),
    135                                               double epsilon) {
    136   uint8_t A[] = {8, 16, 24, 32, 40, 48};
    137   int8_t  B[] = {2, 0, 0, 2, 0, 0};
    138   uint8_t C[4], E[4];
    139 
    140   ref_u8_i8_u8(2, 3, 2, A, B, E, 4.0);
    141   matmul_fn(2, 3, 2, A, B, C, 4.0);
    142 
    143   for (int i = 0; i < 4; i++) {
    144     int d = E[i] - C[i];
    145     if (d < 0) d = -d;
    146     if (d > (int)epsilon) return MUNIT_FAIL;
    147   }
    148   return MUNIT_OK;
    149 }
    150 
    151 static MunitResult test_u8_i8_u8_scaled_medium(const char *name,
    152                                                int (*matmul_fn)(size_t, size_t, size_t, const uint8_t *, const int8_t *,
    153                                                                 uint8_t *, double),
    154                                                double epsilon) {
    155   const size_t m = 64, n = 64, p = 64;
    156   uint8_t     *A = malloc(m * n);
    157   int8_t      *B = malloc(n * p);
    158   uint8_t     *C = malloc(m * p);
    159   uint8_t     *E = malloc(m * p);
    160   if (!A || !B || !C || !E) {
    161     free(A);
    162     free(B);
    163     free(C);
    164     free(E);
    165     return MUNIT_SKIP;
    166   }
    167 
    168   for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)((i * 7 + 13) % 251);
    169   for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(((i * 11 + 17) % 211) - 105);
    170   memset(C, 0, m * p);
    171   memset(E, 0, m * p);
    172 
    173   ref_u8_i8_u8(m, n, p, A, B, E, 8.0);
    174   matmul_fn(m, n, p, A, B, C, 8.0);
    175 
    176   for (size_t i = 0; i < m * p; i++) {
    177     int d = (int)E[i] - (int)C[i];
    178     if (d < 0) d = -d;
    179     if (d > (int)epsilon) {
    180       free(A);
    181       free(B);
    182       free(C);
    183       free(E);
    184       return MUNIT_FAIL;
    185     }
    186   }
    187 
    188   free(A);
    189   free(B);
    190   free(C);
    191   free(E);
    192   return MUNIT_OK;
    193 }
    194 
    195 static MunitResult test_scalar_u8_i8_u8_scaled(const MunitParameter *params, void *data) {
    196   (void)params;
    197   (void)data;
    198   return test_u8_i8_u8_scaled_small("scalar", matmul_scalar_u8_i8_u8, 0);
    199 }
    200 
    201 static MunitResult test_scalar_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) {
    202   (void)params;
    203   (void)data;
    204   return test_u8_i8_u8_scaled_medium("scalar", matmul_scalar_u8_i8_u8, 0);
    205 }
    206 
    207 #ifdef __AVX512VNNI__
    208 static MunitResult test_avx512vnni_u8_i8_u8_scaled(const MunitParameter *params, void *data) {
    209   (void)params;
    210   (void)data;
    211   return test_u8_i8_u8_scaled_small("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0);
    212 }
    213 
    214 static MunitResult test_avx512vnni_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) {
    215   (void)params;
    216   (void)data;
    217   return test_u8_i8_u8_scaled_medium("avx512vnni", matmul_avx512vnni_u8_i8_u8, 0);
    218 }
    219 #endif
    220 
    221 static MunitResult test_dispatched_u8_i8_u8_scaled(const MunitParameter *params, void *data) {
    222   (void)params;
    223   (void)data;
    224   return test_u8_i8_u8_scaled_small("dispatched", matmul_u8_i8_u8, 0);
    225 }
    226 
    227 static MunitResult test_dispatched_u8_i8_u8_scaled_medium(const MunitParameter *params, void *data) {
    228   (void)params;
    229   (void)data;
    230   return test_u8_i8_u8_scaled_medium("dispatched", matmul_u8_i8_u8, 0);
    231 }
    232 
    233 /* ========================================================================== */
    234 /* f32_f32_f32 tests                                                          */
    235 /* ========================================================================== */
    236 
    237 static void ref_f32_f32_f32(size_t m, size_t n, size_t p, const float *A, const float *B, float *C, double scale) {
    238   for (size_t i = 0; i < m; i++)
    239     for (size_t j = 0; j < p; j++) {
    240       double sum = 0.0;
    241       for (size_t k = 0; k < n; k++) sum += (double)A[i * n + k] * (double)B[k * p + j];
    242       if (scale > 1.0) sum /= scale;
    243       C[i * p + j] = (float)sum;
    244     }
    245 }
    246 
    247 static MunitResult test_f32_f32_f32_small(const char *name,
    248                                           int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *,
    249                                                            float *, double),
    250                                           double epsilon) {
    251   float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
    252   float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
    253   float C[4], E[4];
    254 
    255   ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0);
    256   matmul_fn(2, 3, 2, A, B, C, 0.0);
    257 
    258   for (int i = 0; i < 4; i++) {
    259     double d = (double)E[i] - (double)C[i];
    260     if (d < 0) d = -d;
    261     if (d > epsilon) return MUNIT_FAIL;
    262   }
    263   return MUNIT_OK;
    264 }
    265 
    266 static MunitResult test_f32_f32_f32_medium(const char *name,
    267                                            int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *,
    268                                                             float *, double),
    269                                            double epsilon) {
    270   const size_t m = 64, n = 64, p = 64;
    271   float       *A = malloc(m * n * sizeof(float));
    272   float       *B = malloc(n * p * sizeof(float));
    273   float       *C = malloc(m * p * sizeof(float));
    274   float       *E = malloc(m * p * sizeof(float));
    275   if (!A || !B || !C || !E) {
    276     free(A);
    277     free(B);
    278     free(C);
    279     free(E);
    280     return MUNIT_SKIP;
    281   }
    282 
    283   for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251);
    284   for (size_t i = 0; i < n * p; i++) B[i] = (float)(((i * 11 + 17) % 211) - 105);
    285   memset(C, 0, m * p * sizeof(float));
    286   memset(E, 0, m * p * sizeof(float));
    287 
    288   ref_f32_f32_f32(m, n, p, A, B, E, 0.0);
    289   matmul_fn(m, n, p, A, B, C, 0.0);
    290 
    291   for (size_t i = 0; i < m * p; i++) {
    292     double d = (double)E[i] - (double)C[i];
    293     if (d < 0) d = -d;
    294     if (d > epsilon) {
    295       free(A);
    296       free(B);
    297       free(C);
    298       free(E);
    299       return MUNIT_FAIL;
    300     }
    301   }
    302 
    303   free(A);
    304   free(B);
    305   free(C);
    306   free(E);
    307   return MUNIT_OK;
    308 }
    309 
    310 static MunitResult test_f32_f32_f32_scaled_small(const char *name,
    311                                                  int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *,
    312                                                                   float *, double),
    313                                                  double epsilon) {
    314   float A[] = {8.0f, 16.0f, 24.0f, 32.0f, 40.0f, 48.0f};
    315   float B[] = {2.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f};
    316   float C[4], E[4];
    317 
    318   ref_f32_f32_f32(2, 3, 2, A, B, E, 4.0);
    319   matmul_fn(2, 3, 2, A, B, C, 4.0);
    320 
    321   for (int i = 0; i < 4; i++) {
    322     double d = (double)E[i] - (double)C[i];
    323     if (d < 0) d = -d;
    324     if (d > epsilon) return MUNIT_FAIL;
    325   }
    326   return MUNIT_OK;
    327 }
    328 
    329 static MunitResult test_f32_f32_f32_scaled_medium(const char *name,
    330                                                   int (*matmul_fn)(size_t, size_t, size_t, const float *, const float *,
    331                                                                    float *, double),
    332                                                   double epsilon) {
    333   const size_t m = 64, n = 64, p = 64;
    334   float       *A = malloc(m * n * sizeof(float));
    335   float       *B = malloc(n * p * sizeof(float));
    336   float       *C = malloc(m * p * sizeof(float));
    337   float       *E = malloc(m * p * sizeof(float));
    338   if (!A || !B || !C || !E) {
    339     free(A);
    340     free(B);
    341     free(C);
    342     free(E);
    343     return MUNIT_SKIP;
    344   }
    345 
    346   for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251);
    347   for (size_t i = 0; i < n * p; i++) B[i] = (float)(((i * 11 + 17) % 211) - 105);
    348   memset(C, 0, m * p * sizeof(float));
    349   memset(E, 0, m * p * sizeof(float));
    350 
    351   ref_f32_f32_f32(m, n, p, A, B, E, 8.0);
    352   matmul_fn(m, n, p, A, B, C, 8.0);
    353 
    354   for (size_t i = 0; i < m * p; i++) {
    355     double d = (double)E[i] - (double)C[i];
    356     if (d < 0) d = -d;
    357     if (d > epsilon) {
    358       free(A);
    359       free(B);
    360       free(C);
    361       free(E);
    362       return MUNIT_FAIL;
    363     }
    364   }
    365 
    366   free(A);
    367   free(B);
    368   free(C);
    369   free(E);
    370   return MUNIT_OK;
    371 }
    372 
    373 static MunitResult test_scalar_f32_f32_f32(const MunitParameter *params, void *data) {
    374   (void)params;
    375   (void)data;
    376   return test_f32_f32_f32_small("scalar", matmul_scalar_f32_f32_f32, 1e-5);
    377 }
    378 
    379 static MunitResult test_scalar_f32_f32_f32_medium(const MunitParameter *params, void *data) {
    380   (void)params;
    381   (void)data;
    382   return test_f32_f32_f32_medium("scalar", matmul_scalar_f32_f32_f32, 1e-3);
    383 }
    384 
    385 static MunitResult test_scalar_f32_f32_f32_scaled(const MunitParameter *params, void *data) {
    386   (void)params;
    387   (void)data;
    388   return test_f32_f32_f32_scaled_small("scalar", matmul_scalar_f32_f32_f32, 1e-5);
    389 }
    390 
    391 static MunitResult test_scalar_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) {
    392   (void)params;
    393   (void)data;
    394   return test_f32_f32_f32_scaled_medium("scalar", matmul_scalar_f32_f32_f32, 1e-3);
    395 }
    396 
    397 #ifdef __AVX2__
    398 static MunitResult test_avx2_f32_f32_f32(const MunitParameter *params, void *data) {
    399   (void)params;
    400   (void)data;
    401   return test_f32_f32_f32_small("avx2", matmul_avx2_f32_f32_f32, 1e-5);
    402 }
    403 
    404 static MunitResult test_avx2_f32_f32_f32_medium(const MunitParameter *params, void *data) {
    405   (void)params;
    406   (void)data;
    407   return test_f32_f32_f32_medium("avx2", matmul_avx2_f32_f32_f32, 1e-3);
    408 }
    409 
    410 static MunitResult test_avx2_f32_f32_f32_scaled(const MunitParameter *params, void *data) {
    411   (void)params;
    412   (void)data;
    413   return test_f32_f32_f32_scaled_small("avx2", matmul_avx2_f32_f32_f32, 1e-5);
    414 }
    415 
    416 static MunitResult test_avx2_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) {
    417   (void)params;
    418   (void)data;
    419   return test_f32_f32_f32_scaled_medium("avx2", matmul_avx2_f32_f32_f32, 1e-3);
    420 }
    421 #endif
    422 
    423 #ifdef __AVX512F__
    424 static MunitResult test_avx512_f32_f32_f32(const MunitParameter *params, void *data) {
    425   (void)params;
    426   (void)data;
    427   return test_f32_f32_f32_small("avx512", matmul_avx512_f32_f32_f32, 1e-5);
    428 }
    429 
    430 static MunitResult test_avx512_f32_f32_f32_medium(const MunitParameter *params, void *data) {
    431   (void)params;
    432   (void)data;
    433   return test_f32_f32_f32_medium("avx512", matmul_avx512_f32_f32_f32, 1e-3);
    434 }
    435 
    436 static MunitResult test_avx512_f32_f32_f32_scaled(const MunitParameter *params, void *data) {
    437   (void)params;
    438   (void)data;
    439   return test_f32_f32_f32_scaled_small("avx512", matmul_avx512_f32_f32_f32, 1e-5);
    440 }
    441 
    442 static MunitResult test_avx512_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) {
    443   (void)params;
    444   (void)data;
    445   return test_f32_f32_f32_scaled_medium("avx512", matmul_avx512_f32_f32_f32, 1e-3);
    446 }
    447 #endif
    448 
    449 static MunitResult test_dispatched_f32_f32_f32(const MunitParameter *params, void *data) {
    450   (void)params;
    451   (void)data;
    452   return test_f32_f32_f32_small("dispatched", matmul_f32_f32_f32, 1e-5);
    453 }
    454 
    455 static MunitResult test_dispatched_f32_f32_f32_medium(const MunitParameter *params, void *data) {
    456   (void)params;
    457   (void)data;
    458   return test_f32_f32_f32_medium("dispatched", matmul_f32_f32_f32, 1e-3);
    459 }
    460 
    461 static MunitResult test_dispatched_f32_f32_f32_scaled(const MunitParameter *params, void *data) {
    462   (void)params;
    463   (void)data;
    464   return test_f32_f32_f32_scaled_small("dispatched", matmul_f32_f32_f32, 1e-5);
    465 }
    466 
    467 static MunitResult test_dispatched_f32_f32_f32_scaled_medium(const MunitParameter *params, void *data) {
    468   (void)params;
    469   (void)data;
    470   return test_f32_f32_f32_scaled_medium("dispatched", matmul_f32_f32_f32, 1e-3);
    471 }
    472 
    473 /* ========================================================================== */
    474 /* f64_f64_f64 tests                                                          */
    475 /* ========================================================================== */
    476 
    477 static void ref_f64_f64_f64(size_t m, size_t n, size_t p, const double *A, const double *B, double *C, double scale) {
    478   for (size_t i = 0; i < m; i++)
    479     for (size_t j = 0; j < p; j++) {
    480       double sum = 0.0;
    481       for (size_t k = 0; k < n; k++) sum += A[i * n + k] * B[k * p + j];
    482       if (scale > 1.0) sum /= scale;
    483       C[i * p + j] = sum;
    484     }
    485 }
    486 
    487 static MunitResult test_f64_f64_f64_small(const char *name,
    488                                           int (*matmul_fn)(size_t, size_t, size_t, const double *, const double *,
    489                                                            double *, double),
    490                                           double epsilon) {
    491   double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
    492   double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0};
    493   double C[4], E[4];
    494 
    495   ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0);
    496   matmul_fn(2, 3, 2, A, B, C, 0.0);
    497 
    498   for (int i = 0; i < 4; i++) {
    499     double d = E[i] - C[i];
    500     if (d < 0) d = -d;
    501     if (d > epsilon) return MUNIT_FAIL;
    502   }
    503   return MUNIT_OK;
    504 }
    505 
    506 static MunitResult test_f64_f64_f64_medium(const char *name,
    507                                            int (*matmul_fn)(size_t, size_t, size_t, const double *, const double *,
    508                                                             double *, double),
    509                                            double epsilon) {
    510   const size_t m = 64, n = 64, p = 64;
    511   double      *A = malloc(m * n * sizeof(double));
    512   double      *B = malloc(n * p * sizeof(double));
    513   double      *C = malloc(m * p * sizeof(double));
    514   double      *E = malloc(m * p * sizeof(double));
    515   if (!A || !B || !C || !E) {
    516     free(A);
    517     free(B);
    518     free(C);
    519     free(E);
    520     return MUNIT_SKIP;
    521   }
    522 
    523   for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251);
    524   for (size_t i = 0; i < n * p; i++) B[i] = (double)(((i * 11 + 17) % 211) - 105);
    525   memset(C, 0, m * p * sizeof(double));
    526   memset(E, 0, m * p * sizeof(double));
    527 
    528   ref_f64_f64_f64(m, n, p, A, B, E, 0.0);
    529   matmul_fn(m, n, p, A, B, C, 0.0);
    530 
    531   for (size_t i = 0; i < m * p; i++) {
    532     double d = E[i] - C[i];
    533     if (d < 0) d = -d;
    534     if (d > epsilon) {
    535       free(A);
    536       free(B);
    537       free(C);
    538       free(E);
    539       return MUNIT_FAIL;
    540     }
    541   }
    542 
    543   free(A);
    544   free(B);
    545   free(C);
    546   free(E);
    547   return MUNIT_OK;
    548 }
    549 
    550 static MunitResult test_f64_f64_f64_scaled_small(const char *name,
    551                                                  int (*matmul_fn)(size_t, size_t, size_t, const double *,
    552                                                                   const double *, double *, double),
    553                                                  double epsilon) {
    554   double A[] = {8.0, 16.0, 24.0, 32.0, 40.0, 48.0};
    555   double B[] = {2.0, 0.0, 0.0, 2.0, 0.0, 0.0};
    556   double C[4], E[4];
    557 
    558   ref_f64_f64_f64(2, 3, 2, A, B, E, 4.0);
    559   matmul_fn(2, 3, 2, A, B, C, 4.0);
    560 
    561   for (int i = 0; i < 4; i++) {
    562     double d = E[i] - C[i];
    563     if (d < 0) d = -d;
    564     if (d > epsilon) return MUNIT_FAIL;
    565   }
    566   return MUNIT_OK;
    567 }
    568 
    569 static MunitResult test_f64_f64_f64_scaled_medium(const char *name,
    570                                                   int (*matmul_fn)(size_t, size_t, size_t, const double *,
    571                                                                    const double *, double *, double),
    572                                                   double epsilon) {
    573   const size_t m = 64, n = 64, p = 64;
    574   double      *A = malloc(m * n * sizeof(double));
    575   double      *B = malloc(n * p * sizeof(double));
    576   double      *C = malloc(m * p * sizeof(double));
    577   double      *E = malloc(m * p * sizeof(double));
    578   if (!A || !B || !C || !E) {
    579     free(A);
    580     free(B);
    581     free(C);
    582     free(E);
    583     return MUNIT_SKIP;
    584   }
    585 
    586   for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251);
    587   for (size_t i = 0; i < n * p; i++) B[i] = (double)(((i * 11 + 17) % 211) - 105);
    588   memset(C, 0, m * p * sizeof(double));
    589   memset(E, 0, m * p * sizeof(double));
    590 
    591   ref_f64_f64_f64(m, n, p, A, B, E, 8.0);
    592   matmul_fn(m, n, p, A, B, C, 8.0);
    593 
    594   for (size_t i = 0; i < m * p; i++) {
    595     double d = E[i] - C[i];
    596     if (d < 0) d = -d;
    597     if (d > epsilon) {
    598       free(A);
    599       free(B);
    600       free(C);
    601       free(E);
    602       return MUNIT_FAIL;
    603     }
    604   }
    605 
    606   free(A);
    607   free(B);
    608   free(C);
    609   free(E);
    610   return MUNIT_OK;
    611 }
    612 
    613 static MunitResult test_scalar_f64_f64_f64(const MunitParameter *params, void *data) {
    614   (void)params;
    615   (void)data;
    616   return test_f64_f64_f64_small("scalar", matmul_scalar_f64_f64_f64, 1e-12);
    617 }
    618 
    619 static MunitResult test_scalar_f64_f64_f64_medium(const MunitParameter *params, void *data) {
    620   (void)params;
    621   (void)data;
    622   return test_f64_f64_f64_medium("scalar", matmul_scalar_f64_f64_f64, 1e-9);
    623 }
    624 
    625 static MunitResult test_scalar_f64_f64_f64_scaled(const MunitParameter *params, void *data) {
    626   (void)params;
    627   (void)data;
    628   return test_f64_f64_f64_scaled_small("scalar", matmul_scalar_f64_f64_f64, 1e-12);
    629 }
    630 
    631 static MunitResult test_scalar_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) {
    632   (void)params;
    633   (void)data;
    634   return test_f64_f64_f64_scaled_medium("scalar", matmul_scalar_f64_f64_f64, 1e-9);
    635 }
    636 
    637 #ifdef __AVX2__
    638 static MunitResult test_avx2_f64_f64_f64(const MunitParameter *params, void *data) {
    639   (void)params;
    640   (void)data;
    641   return test_f64_f64_f64_small("avx2", matmul_avx2_f64_f64_f64, 1e-12);
    642 }
    643 
    644 static MunitResult test_avx2_f64_f64_f64_medium(const MunitParameter *params, void *data) {
    645   (void)params;
    646   (void)data;
    647   return test_f64_f64_f64_medium("avx2", matmul_avx2_f64_f64_f64, 1e-9);
    648 }
    649 
    650 static MunitResult test_avx2_f64_f64_f64_scaled(const MunitParameter *params, void *data) {
    651   (void)params;
    652   (void)data;
    653   return test_f64_f64_f64_scaled_small("avx2", matmul_avx2_f64_f64_f64, 1e-12);
    654 }
    655 
    656 static MunitResult test_avx2_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) {
    657   (void)params;
    658   (void)data;
    659   return test_f64_f64_f64_scaled_medium("avx2", matmul_avx2_f64_f64_f64, 1e-9);
    660 }
    661 #endif
    662 
    663 #ifdef __AVX512F__
    664 static MunitResult test_avx512_f64_f64_f64(const MunitParameter *params, void *data) {
    665   (void)params;
    666   (void)data;
    667   return test_f64_f64_f64_small("avx512", matmul_avx512_f64_f64_f64, 1e-12);
    668 }
    669 
    670 static MunitResult test_avx512_f64_f64_f64_medium(const MunitParameter *params, void *data) {
    671   (void)params;
    672   (void)data;
    673   return test_f64_f64_f64_medium("avx512", matmul_avx512_f64_f64_f64, 1e-9);
    674 }
    675 
    676 static MunitResult test_avx512_f64_f64_f64_scaled(const MunitParameter *params, void *data) {
    677   (void)params;
    678   (void)data;
    679   return test_f64_f64_f64_scaled_small("avx512", matmul_avx512_f64_f64_f64, 1e-12);
    680 }
    681 
    682 static MunitResult test_avx512_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) {
    683   (void)params;
    684   (void)data;
    685   return test_f64_f64_f64_scaled_medium("avx512", matmul_avx512_f64_f64_f64, 1e-9);
    686 }
    687 #endif
    688 
    689 static MunitResult test_dispatched_f64_f64_f64(const MunitParameter *params, void *data) {
    690   (void)params;
    691   (void)data;
    692   return test_f64_f64_f64_small("dispatched", matmul_f64_f64_f64, 1e-12);
    693 }
    694 
    695 static MunitResult test_dispatched_f64_f64_f64_medium(const MunitParameter *params, void *data) {
    696   (void)params;
    697   (void)data;
    698   return test_f64_f64_f64_medium("dispatched", matmul_f64_f64_f64, 1e-9);
    699 }
    700 
    701 static MunitResult test_dispatched_f64_f64_f64_scaled(const MunitParameter *params, void *data) {
    702   (void)params;
    703   (void)data;
    704   return test_f64_f64_f64_scaled_small("dispatched", matmul_f64_f64_f64, 1e-12);
    705 }
    706 
    707 static MunitResult test_dispatched_f64_f64_f64_scaled_medium(const MunitParameter *params, void *data) {
    708   (void)params;
    709   (void)data;
    710   return test_f64_f64_f64_scaled_medium("dispatched", matmul_f64_f64_f64, 1e-9);
    711 }
    712 
    713 /* ========================================================================== */
    714 /* matmul generic macro tests                                                 */
    715 /* ========================================================================== */
    716 
    717 static MunitResult test_matmul_u8_i8_u8(const char *name, const char *type) {
    718   const uint8_t A[] = {1, 2, 3, 4, 5, 6};
    719   const int8_t  B[] = {1, 0, 0, 1, 0, 0};
    720   uint8_t       C[4], E[4];
    721 
    722   ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0);
    723   matmul(2, 3, 2, A, B, C, 0.0);
    724 
    725   for (int i = 0; i < 4; i++) {
    726     int d = E[i] - C[i];
    727     if (d < 0) d = -d;
    728     if (d > 0) return MUNIT_FAIL;
    729   }
    730   return MUNIT_OK;
    731 }
    732 
    733 static MunitResult test_matmul_u8_i8_u8_scaled(const char *name, const char *type) {
    734   const uint8_t A[] = {8, 16, 24, 32, 40, 48};
    735   const int8_t  B[] = {2, 0, 0, 2, 0, 0};
    736   uint8_t       C[4], E[4];
    737 
    738   ref_u8_i8_u8(2, 3, 2, A, B, E, 4.0);
    739   matmul(2, 3, 2, A, B, C, 4.0);
    740 
    741   for (int i = 0; i < 4; i++) {
    742     int d = E[i] - C[i];
    743     if (d < 0) d = -d;
    744     if (d > 0) return MUNIT_FAIL;
    745   }
    746   return MUNIT_OK;
    747 }
    748 
    749 static MunitResult test_matmul_u8_i8_u8_medium(const char *name, const char *type) {
    750   const size_t m = 64, n = 64, p = 64;
    751   uint8_t     *A = malloc(m * n);
    752   int8_t      *B = malloc(n * p);
    753   uint8_t     *C = malloc(m * p);
    754   uint8_t     *E = malloc(m * p);
    755   if (!A || !B || !C || !E) {
    756     free(A);
    757     free(B);
    758     free(C);
    759     free(E);
    760     return MUNIT_SKIP;
    761   }
    762 
    763   for (size_t i = 0; i < m * n; i++) A[i] = (uint8_t)((i * 7 + 13) % 251);
    764   for (size_t i = 0; i < n * p; i++) B[i] = (int8_t)(((i * 11 + 17) % 211) - 105);
    765   memset(C, 0, m * p);
    766   memset(E, 0, m * p);
    767 
    768   ref_u8_i8_u8(m, n, p, A, B, E, 0.0);
    769   matmul(m, n, p, A, B, C, 0.0);
    770 
    771   for (size_t i = 0; i < m * p; i++) {
    772     int d = (int)E[i] - (int)C[i];
    773     if (d < 0) d = -d;
    774     if (d > 0) {
    775       free(A);
    776       free(B);
    777       free(C);
    778       free(E);
    779       return MUNIT_FAIL;
    780     }
    781   }
    782 
    783   free(A);
    784   free(B);
    785   free(C);
    786   free(E);
    787   return MUNIT_OK;
    788 }
    789 
    790 static MunitResult test_matmul_f32_f32_f32(const char *name, const char *type) {
    791   const float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
    792   const float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
    793   float       C[4], E[4];
    794 
    795   ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0);
    796   matmul(2, 3, 2, A, B, C, 0.0);
    797 
    798   for (int i = 0; i < 4; i++) {
    799     double d = (double)E[i] - (double)C[i];
    800     if (d < 0) d = -d;
    801     if (d > 1e-5) return MUNIT_FAIL;
    802   }
    803   return MUNIT_OK;
    804 }
    805 
    806 static MunitResult test_matmul_f32_f32_f32_scaled(const char *name, const char *type) {
    807   const float A[] = {8.0f, 16.0f, 24.0f, 32.0f, 40.0f, 48.0f};
    808   const float B[] = {2.0f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f};
    809   float       C[4], E[4];
    810 
    811   ref_f32_f32_f32(2, 3, 2, A, B, E, 4.0);
    812   matmul(2, 3, 2, A, B, C, 4.0);
    813 
    814   for (int i = 0; i < 4; i++) {
    815     double d = (double)E[i] - (double)C[i];
    816     if (d < 0) d = -d;
    817     if (d > 1e-5) return MUNIT_FAIL;
    818   }
    819   return MUNIT_OK;
    820 }
    821 
    822 static MunitResult test_matmul_f32_f32_f32_medium(const char *name, const char *type) {
    823   const size_t m = 64, n = 64, p = 64;
    824   float       *A = malloc(m * n * sizeof(float));
    825   float       *B = malloc(n * p * sizeof(float));
    826   float       *C = malloc(m * p * sizeof(float));
    827   float       *E = malloc(m * p * sizeof(float));
    828   if (!A || !B || !C || !E) {
    829     free(A);
    830     free(B);
    831     free(C);
    832     free(E);
    833     return MUNIT_SKIP;
    834   }
    835 
    836   for (size_t i = 0; i < m * n; i++) A[i] = (float)((i * 7 + 13) % 251);
    837   for (size_t i = 0; i < n * p; i++) B[i] = (float)(((i * 11 + 17) % 211) - 105);
    838   memset(C, 0, m * p * sizeof(float));
    839   memset(E, 0, m * p * sizeof(float));
    840 
    841   ref_f32_f32_f32(m, n, p, A, B, E, 0.0);
    842   matmul(m, n, p, A, B, C, 0.0);
    843 
    844   for (size_t i = 0; i < m * p; i++) {
    845     double d = (double)E[i] - (double)C[i];
    846     if (d < 0) d = -d;
    847     if (d > 1e-3) {
    848       free(A);
    849       free(B);
    850       free(C);
    851       free(E);
    852       return MUNIT_FAIL;
    853     }
    854   }
    855 
    856   free(A);
    857   free(B);
    858   free(C);
    859   free(E);
    860   return MUNIT_OK;
    861 }
    862 
    863 static MunitResult test_matmul_f64_f64_f64(const char *name, const char *type) {
    864   const double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
    865   const double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0};
    866   double       C[4], E[4];
    867 
    868   ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0);
    869   matmul(2, 3, 2, A, B, C, 0.0);
    870 
    871   for (int i = 0; i < 4; i++) {
    872     double d = E[i] - C[i];
    873     if (d < 0) d = -d;
    874     if (d > 1e-12) return MUNIT_FAIL;
    875   }
    876   return MUNIT_OK;
    877 }
    878 
    879 static MunitResult test_matmul_f64_f64_f64_scaled(const char *name, const char *type) {
    880   const double A[] = {8.0, 16.0, 24.0, 32.0, 40.0, 48.0};
    881   const double B[] = {2.0, 0.0, 0.0, 2.0, 0.0, 0.0};
    882   double       C[4], E[4];
    883 
    884   ref_f64_f64_f64(2, 3, 2, A, B, E, 4.0);
    885   matmul(2, 3, 2, A, B, C, 4.0);
    886 
    887   for (int i = 0; i < 4; i++) {
    888     double d = E[i] - C[i];
    889     if (d < 0) d = -d;
    890     if (d > 1e-12) return MUNIT_FAIL;
    891   }
    892   return MUNIT_OK;
    893 }
    894 
    895 static MunitResult test_matmul_f64_f64_f64_medium(const char *name, const char *type) {
    896   const size_t m = 64, n = 64, p = 64;
    897   double      *A = malloc(m * n * sizeof(double));
    898   double      *B = malloc(n * p * sizeof(double));
    899   double      *C = malloc(m * p * sizeof(double));
    900   double      *E = malloc(m * p * sizeof(double));
    901   if (!A || !B || !C || !E) {
    902     free(A);
    903     free(B);
    904     free(C);
    905     free(E);
    906     return MUNIT_SKIP;
    907   }
    908 
    909   for (size_t i = 0; i < m * n; i++) A[i] = (double)((i * 7 + 13) % 251);
    910   for (size_t i = 0; i < n * p; i++) B[i] = (double)(((i * 11 + 17) % 211) - 105);
    911   memset(C, 0, m * p * sizeof(double));
    912   memset(E, 0, m * p * sizeof(double));
    913 
    914   ref_f64_f64_f64(m, n, p, A, B, E, 0.0);
    915   matmul(m, n, p, A, B, C, 0.0);
    916 
    917   for (size_t i = 0; i < m * p; i++) {
    918     double d = E[i] - C[i];
    919     if (d < 0) d = -d;
    920     if (d > 1e-9) {
    921       free(A);
    922       free(B);
    923       free(C);
    924       free(E);
    925       return MUNIT_FAIL;
    926     }
    927   }
    928 
    929   free(A);
    930   free(B);
    931   free(C);
    932   free(E);
    933   return MUNIT_OK;
    934 }
    935 
    936 static MunitResult test_matmul_u8_i8_u8_const(const MunitParameter *params, void *data) {
    937   (void)params;
    938   (void)data;
    939   return test_matmul_u8_i8_u8("const", "u8_i8_u8");
    940 }
    941 
    942 static MunitResult test_matmul_u8_i8_u8_nonconst(const MunitParameter *params, void *data) {
    943   (void)params;
    944   (void)data;
    945   uint8_t A[] = {1, 2, 3, 4, 5, 6};
    946   int8_t  B[] = {1, 0, 0, 1, 0, 0};
    947   uint8_t C[4], E[4];
    948 
    949   ref_u8_i8_u8(2, 3, 2, A, B, E, 0.0);
    950   matmul(2, 3, 2, A, B, C, 0.0);
    951 
    952   for (int i = 0; i < 4; i++) {
    953     int d = E[i] - C[i];
    954     if (d < 0) d = -d;
    955     if (d > 0) return MUNIT_FAIL;
    956   }
    957   return MUNIT_OK;
    958 }
    959 
    960 static MunitResult test_matmul_u8_i8_u8_scaled_const(const MunitParameter *params, void *data) {
    961   (void)params;
    962   (void)data;
    963   return test_matmul_u8_i8_u8_scaled("const", "u8_i8_u8");
    964 }
    965 
    966 static MunitResult test_matmul_u8_i8_u8_medium_const(const MunitParameter *params, void *data) {
    967   (void)params;
    968   (void)data;
    969   return test_matmul_u8_i8_u8_medium("const", "u8_i8_u8");
    970 }
    971 
    972 static MunitResult test_matmul_f32_f32_f32_const(const MunitParameter *params, void *data) {
    973   (void)params;
    974   (void)data;
    975   return test_matmul_f32_f32_f32("const", "f32_f32_f32");
    976 }
    977 
    978 static int matmul_f32_f32_f32_nonconst(size_t m, size_t n, size_t p, float *A, float *B, float *C, double scale) {
    979   return matmul(m, n, p, A, B, C, scale);
    980 }
    981 
    982 static MunitResult test_matmul_f32_f32_f32_nonconst(const MunitParameter *params, void *data) {
    983   (void)params;
    984   (void)data;
    985   float A[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
    986   float B[] = {1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f};
    987   float C[4], E[4];
    988 
    989   ref_f32_f32_f32(2, 3, 2, A, B, E, 0.0);
    990   matmul_f32_f32_f32_nonconst(2, 3, 2, A, B, C, 0.0);
    991 
    992   for (int i = 0; i < 4; i++) {
    993     double d = (double)E[i] - (double)C[i];
    994     if (d < 0) d = -d;
    995     if (d > 1e-5) return MUNIT_FAIL;
    996   }
    997   return MUNIT_OK;
    998 }
    999 
   1000 static MunitResult test_matmul_f32_f32_f32_scaled_const(const MunitParameter *params, void *data) {
   1001   (void)params;
   1002   (void)data;
   1003   return test_matmul_f32_f32_f32_scaled("const", "f32_f32_f32");
   1004 }
   1005 
   1006 static MunitResult test_matmul_f32_f32_f32_medium_const(const MunitParameter *params, void *data) {
   1007   (void)params;
   1008   (void)data;
   1009   return test_matmul_f32_f32_f32_medium("const", "f32_f32_f32");
   1010 }
   1011 
   1012 static MunitResult test_matmul_f64_f64_f64_const(const MunitParameter *params, void *data) {
   1013   (void)params;
   1014   (void)data;
   1015   return test_matmul_f64_f64_f64("const", "f64_f64_f64");
   1016 }
   1017 
   1018 static int matmul_f64_f64_f64_nonconst(size_t m, size_t n, size_t p, double *A, double *B, double *C, double scale) {
   1019   return matmul(m, n, p, A, B, C, scale);
   1020 }
   1021 
   1022 static MunitResult test_matmul_f64_f64_f64_nonconst(const MunitParameter *params, void *data) {
   1023   (void)params;
   1024   (void)data;
   1025   double A[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
   1026   double B[] = {1.0, 0.0, 0.0, 1.0, 0.0, 0.0};
   1027   double C[4], E[4];
   1028 
   1029   ref_f64_f64_f64(2, 3, 2, A, B, E, 0.0);
   1030   matmul_f64_f64_f64_nonconst(2, 3, 2, A, B, C, 0.0);
   1031 
   1032   for (int i = 0; i < 4; i++) {
   1033     double d = E[i] - C[i];
   1034     if (d < 0) d = -d;
   1035     if (d > 1e-12) return MUNIT_FAIL;
   1036   }
   1037   return MUNIT_OK;
   1038 }
   1039 
   1040 static MunitResult test_matmul_f64_f64_f64_scaled_const(const MunitParameter *params, void *data) {
   1041   (void)params;
   1042   (void)data;
   1043   return test_matmul_f64_f64_f64_scaled("const", "f64_f64_f64");
   1044 }
   1045 
   1046 static MunitResult test_matmul_f64_f64_f64_medium_const(const MunitParameter *params, void *data) {
   1047   (void)params;
   1048   (void)data;
   1049   return test_matmul_f64_f64_f64_medium("const", "f64_f64_f64");
   1050 }
   1051 
   1052 static MunitTest tests[] = {
   1053     {"/scalar-u8-i8-u8", test_scalar_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1054     {"/scalar-u8-i8-u8-medium", test_scalar_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1055     {"/scalar-u8-i8-u8-scaled", test_scalar_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1056     {"/scalar-u8-i8-u8-scaled-medium", test_scalar_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1057 #ifdef __AVX512VNNI__
   1058     {"/avx512vnni-u8-i8-u8", test_avx512vnni_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1059     {"/avx512vnni-u8-i8-u8-medium", test_avx512vnni_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1060     {"/avx512vnni-u8-i8-u8-scaled", test_avx512vnni_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1061     {"/avx512vnni-u8-i8-u8-scaled-medium", test_avx512vnni_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
   1062      NULL},
   1063 #endif
   1064     {"/dispatched-u8-i8-u8", test_dispatched_u8_i8_u8, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1065     {"/dispatched-u8-i8-u8-medium", test_dispatched_u8_i8_u8_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1066     {"/dispatched-u8-i8-u8-scaled", test_dispatched_u8_i8_u8_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1067     {"/dispatched-u8-i8-u8-scaled-medium", test_dispatched_u8_i8_u8_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
   1068      NULL},
   1069     {"/scalar-f32-f32-f32", test_scalar_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1070     {"/scalar-f32-f32-f32-medium", test_scalar_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1071     {"/scalar-f32-f32-f32-scaled", test_scalar_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1072     {"/scalar-f32-f32-f32-scaled-medium", test_scalar_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
   1073      NULL},
   1074 #ifdef __AVX2__
   1075     {"/avx2-f32-f32-f32", test_avx2_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1076     {"/avx2-f32-f32-f32-medium", test_avx2_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1077     {"/avx2-f32-f32-f32-scaled", test_avx2_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1078     {"/avx2-f32-f32-f32-scaled-medium", test_avx2_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1079 #endif
   1080 #ifdef __AVX512F__
   1081     {"/avx512-f32-f32-f32", test_avx512_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1082     {"/avx512-f32-f32-f32-medium", test_avx512_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1083     {"/avx512-f32-f32-f32-scaled", test_avx512_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1084     {"/avx512-f32-f32-f32-scaled-medium", test_avx512_f32_f32_f32_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
   1085      NULL},
   1086 #endif
   1087     {"/dispatched-f32-f32-f32", test_dispatched_f32_f32_f32, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1088     {"/dispatched-f32-f32-f32-medium", test_dispatched_f32_f32_f32_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1089     {"/dispatched-f32-f32-f32-scaled", test_dispatched_f32_f32_f32_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1090     {"/dispatched-f32-f32-f32-scaled-medium", test_dispatched_f32_f32_f32_scaled_medium, NULL, NULL,
   1091      MUNIT_TEST_OPTION_NONE, NULL},
   1092     {"/scalar-f64-f64-f64", test_scalar_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1093     {"/scalar-f64-f64-f64-medium", test_scalar_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1094     {"/scalar-f64-f64-f64-scaled", test_scalar_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1095     {"/scalar-f64-f64-f64-scaled-medium", test_scalar_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
   1096      NULL},
   1097 #ifdef __AVX2__
   1098     {"/avx2-f64-f64-f64", test_avx2_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1099     {"/avx2-f64-f64-f64-medium", test_avx2_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1100     {"/avx2-f64-f64-f64-scaled", test_avx2_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1101     {"/avx2-f64-f64-f64-scaled-medium", test_avx2_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1102 #endif
   1103 #ifdef __AVX512F__
   1104     {"/avx512-f64-f64-f64", test_avx512_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1105     {"/avx512-f64-f64-f64-medium", test_avx512_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1106     {"/avx512-f64-f64-f64-scaled", test_avx512_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1107     {"/avx512-f64-f64-f64-scaled-medium", test_avx512_f64_f64_f64_scaled_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE,
   1108      NULL},
   1109 #endif
   1110     {"/dispatched-f64-f64-f64", test_dispatched_f64_f64_f64, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1111     {"/dispatched-f64-f64-f64-medium", test_dispatched_f64_f64_f64_medium, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1112     {"/dispatched-f64-f64-f64-scaled", test_dispatched_f64_f64_f64_scaled, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1113     {"/dispatched-f64-f64-f64-scaled-medium", test_dispatched_f64_f64_f64_scaled_medium, NULL, NULL,
   1114      MUNIT_TEST_OPTION_NONE, NULL},
   1115     {"/matmul-u8-i8-u8-const", test_matmul_u8_i8_u8_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1116     {"/matmul-u8-i8-u8-nonconst", test_matmul_u8_i8_u8_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1117     {"/matmul-u8-i8-u8-scaled-const", test_matmul_u8_i8_u8_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1118     {"/matmul-u8-i8-u8-medium-const", test_matmul_u8_i8_u8_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1119     {"/matmul-f32-f32-f32-const", test_matmul_f32_f32_f32_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1120     {"/matmul-f32-f32-f32-nonconst", test_matmul_f32_f32_f32_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1121     {"/matmul-f32-f32-f32-scaled-const", test_matmul_f32_f32_f32_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE,
   1122      NULL},
   1123     {"/matmul-f32-f32-f32-medium-const", test_matmul_f32_f32_f32_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE,
   1124      NULL},
   1125     {"/matmul-f64-f64-f64-const", test_matmul_f64_f64_f64_const, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1126     {"/matmul-f64-f64-f64-nonconst", test_matmul_f64_f64_f64_nonconst, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},
   1127     {"/matmul-f64-f64-f64-scaled-const", test_matmul_f64_f64_f64_scaled_const, NULL, NULL, MUNIT_TEST_OPTION_NONE,
   1128      NULL},
   1129     {"/matmul-f64-f64-f64-medium-const", test_matmul_f64_f64_f64_medium_const, NULL, NULL, MUNIT_TEST_OPTION_NONE,
   1130      NULL},
   1131     {NULL, NULL, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}};
   1132 
   1133 static const MunitSuite suite = {"/matmul", tests, NULL, 1, MUNIT_SUITE_OPTION_NONE};
   1134 
   1135 int main(int argc, char *argv[MUNIT_ARRAY_PARAM(argc)]) {
   1136   return munit_suite_main(&suite, NULL, argc, argv);
   1137 }