#include <errno.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <elcore50-matrix-lib/mat_mul_with_dma_fl16.hpp>

// Converts `size` 16-bit values from `src` into floats in `dst`; presumably
// `src` holds raw IEEE fp16 bit patterns (TODO confirm against the library).
// Declared with C linkage; the definition is not in this file.
extern "C" void float16_to_float32(uint16_t* src, float* dst, int size);

// Problem size (as indexed in main): src0 is row0 x col0row1 and src1 is
// col0row1 x col1, both row-major; dst receives the row0 x col1 product.
const int row0 = 128;
const int col0row1 = 528;
const int col1 = 128;

// Scratch buffer: dst converted to fp32 for the control-sum check.
float fp32[row0 * col1];
// Left/right matmul operands, stored as 16-bit (fp16) values.
uint16_t src0[row0 * col0row1];
uint16_t src1[col0row1 * col1];
// Matmul output, 16-bit (fp16) values.
uint16_t dst[row0 * col1];

// Writes the command-line help text for this test to stdout.
void print_usage() {
  const char* const usage =
      "Usage: test_matmul_perf_f16 [OPTIONS]\n"
      "Options description:\n"
      "\t-i    Number of iterations: unsigned long long. Default: 1\n"
      "\t      For endless run use 0\n"
      "\t-h    Print usage\n";
  fputs(usage, stdout);
}

// Fills src0 and src1 with deterministic test data: src0[i][j] and
// src1[j][i] both get 0.5 * sin(i + j), converted by the `fhcv` instruction
// into the 16-bit value stored in the arrays.
//
// src1 is written transposed with i < row0, so it is fully covered only
// because row0 == col1.
// NOTE(review): `0.5 * sinf(...)` is a double expression (0.5 is a double
// literal) bound to an "r" operand. control_sum in main was presumably
// computed from exactly this data, so do not change the expression without
// recomputing it.
static void fill_inputs() {
  for (size_t i = 0; i < row0; ++i) {
    for (size_t j = 0; j < col0row1; ++j) {
      asm("fhcv %1, %0" : "=r"(src0[i * col0row1 + j]) : "r"(0.5 * sinf(i + j)));
      asm("fhcv %1, %0" : "=r"(src1[j * col1 + i]) : "r"(0.5 * sinf(i + j)));
    }
  }
}

// Parses the value of "-i" as a base-10 unsigned iteration count.
// Returns true and stores the value in *out on success. Rejects NULL/empty
// strings, a leading sign or whitespace, trailing garbage, and values that
// overflow unsigned long long.
static bool parse_iterations(const char* text, unsigned long long* out) {
  if (text == NULL || text[0] < '0' || text[0] > '9') return false;
  errno = 0;
  char* end = NULL;
  const unsigned long long value = strtoull(text, &end, 10);
  if (errno == ERANGE || *end != '\0') return false;
  *out = value;
  return true;
}

// Converts dst to fp32, sums all elements in index order and compares the
// sum with `expected` (tolerance 1.0). Prints a diagnostic and returns false
// on mismatch.
static bool check_control_sum(float expected) {
  float16_to_float32(dst, fp32, row0 * col1);

  float sum = 0.0f;
  for (size_t i = 0; i < (size_t)(row0 * col1); ++i) sum += fp32[i];

  if (fabs(sum - expected) > 1.0f) {
    printf("error %f != %f\n", sum, expected);
    return false;
  }
  return true;
}

// Entry point: parses options, runs the fp16 matmul `iterations` times
// (0 = forever), and validates the result every 500 iterations and after the
// final one. Returns 0 on success, 1 on bad arguments or a wrong result.
int main(int argc, char* argv[]) {
  if (argc > 3) {
    printf("Error: wrong number of parameters: %d!\n", argc);
    print_usage();
    return 1;
  }

  unsigned long long iterations = 1;  // 0 means run forever
  if (argc > 1) {
    const char* key = argv[1];
    if (!strcmp(key, "-h")) {
      print_usage();
      return 0;
    } else if (!strcmp(key, "-i")) {
      // "-i" requires a value; with argc == 2, argv[2] is NULL.
      if (argc < 3) {
        printf("Error: wrong number of parameters: %d!\n", argc);
        print_usage();
        return 1;
      }
      if (!parse_iterations(argv[2], &iterations)) {
        printf("Error: wrong number of iterations: %s!\n", argv[2]);
        print_usage();
        return 1;
      }
    } else {
      printf("Unrecognized program option: %s\n", key);
      print_usage();
      return 1;
    }
  }

  fill_inputs();

  // Expected sum of all dst elements (as fp32) for the data from fill_inputs().
  const float control_sum = 242.96855;

  // Output arguments of the matmul routine; not read by this test.
  // Presumably per-stage tick / instruction counters (TODO confirm).
  int tics[6];
  int instr[6];

  unsigned long long it = 0;
  while (true) {
    mat_mul_fl16_16x128_fl16_general_no_ld_res(src0, row0, col0row1, src1, col1, dst, 0, 0, col1, 0, tics, instr);
    ++it;

    // Validate periodically during long runs, and always after the final
    // iteration so that short runs (including the default of 1) are checked.
    const bool last = (it == iterations);
    if ((last || it % 500 == 0) && !check_control_sum(control_sum)) return 1;
    if (last) break;
  }

  printf("Test passed!\n");

  return 0;
}
