// Copyright 2025 RnD Center "ELVEES", JSC

#include "tests_cplx_func.hpp"

// Prints the horizontal rule that closes each function's group of rows in the results table.
static void print_separator_line() {
  std::cout << "-------------------------------------------------------------------"
               "------------------------------------------------------"
            << std::endl;
}

// Test driver: runs every optimized complex-vector routine against its reference
// implementation for each length in `size`, printing one table row per run.
//
// Returns the sum of the values returned by the individual test helpers
// (0 when every comparison passed), or 1 if the work buffers could not be allocated.
int main() {
  int failed_count = 0;
  int size[TEST_COUNT] = {64, 524, 1060, 2126, 8192};
  // The last entry is used to size and fill the shared buffers.
  const int max_size = size[TEST_COUNT - 1];

  int print = 0;

  print_table_header();

  // Each buffer is sized for 2 * max_size elements of the widest type used below (8 bytes).
#ifndef LOCAL_MEM
  void* src0 = memalign(64, 2 * max_size * sizeof(int64_t));
  void* src1 = memalign(64, 2 * max_size * sizeof(int64_t));
  void* dst_opt = memalign(64, 2 * max_size * sizeof(int64_t));
  void* dst_ref = memalign(64, 2 * max_size * sizeof(int64_t));
  if (!src0 || !src1 || !dst_opt || !dst_ref) {
    // Bail out instead of dereferencing a null buffer; free(NULL) is a no-op.
    std::cout << "Failed to allocate test buffers" << std::endl;
    free(src0);
    free(src1);
    free(dst_opt);
    free(dst_ref);
    return 1;
  }
#else
#ifdef BARE_METAL
  void* src0 = &__local_mem;
#else
  disable_l2_cache();
  void* src0 = &xyram_data;
#endif
  // The four buffers are laid out back to back starting at src0.
  void* src1 = static_cast<int8_t*>(src0) + 2 * max_size * sizeof(int64_t);
  void* dst_opt = static_cast<int8_t*>(src1) + 2 * max_size * sizeof(int64_t);
  void* dst_ref = static_cast<int8_t*>(dst_opt) + 2 * max_size * sizeof(int64_t);
#endif

  // ---- Two-source / one-destination element-wise functions ----

  std::string func_name = "| add_cplx32             |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 2;
    failed_count += test_cplx_func_2src_1dst(ref_add_cplx32, add_cplx32, (int32_t*)src0, (int32_t*)src1,
                                             (int32_t*)dst_ref, (int32_t*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| sub_cplx32             |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 2;
    failed_count += test_cplx_func_2src_1dst(ref_sub_cplx32, sub_cplx32, (int32_t*)src0, (int32_t*)src1,
                                             (int32_t*)dst_ref, (int32_t*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| mul_cplx32             |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 2;
    failed_count += test_cplx_func_2src_1dst(ref_mul_cplx32, mul_cplx32, (int32_t*)src0, (int32_t*)src1,
                                             (int32_t*)dst_ref, (int32_t*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| mul_cplx16_re_im       |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int16_t) * 2;
    failed_count += test_cplx_func_2src_1dst(ref_mul_cplx16_re_im, mul_cplx16_re_im, (int16_t*)src0, (int16_t*)src1,
                                             (int16_t*)dst_ref, (int16_t*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| mul_cplx32_re_im       |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 2;
    failed_count += test_cplx_func_2src_1dst(ref_mul_cplx32_re_im, mul_cplx32_re_im, (int32_t*)src0, (int32_t*)src1,
                                             (int32_t*)dst_ref, (int32_t*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| mul_cplx_fl_re_im      |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(float) * 2;
    failed_count += test_cplx_func_2src_1dst(ref_mul_cplx_fl_re_im, mul_cplx_fl_re_im, (float*)src0, (float*)src1,
                                             (float*)dst_ref, (float*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| mul_cplx_db_re_im      |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(double) * 2;
    failed_count += test_cplx_func_2src_1dst(ref_mul_cplx_db_re_im, mul_cplx_db_re_im, (double*)src0, (double*)src1,
                                             (double*)dst_ref, (double*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| mul_conj_cplx16_re_im  |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int16_t) * 2;
    failed_count +=
        test_cplx_func_2src_1dst(ref_mul_conj_cplx16_re_im, mul_conj_cplx16_re_im, (int16_t*)src0, (int16_t*)src1,
                                 (int16_t*)dst_ref, (int16_t*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| mul_conj_cplx32_re_im  |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 2;
    failed_count +=
        test_cplx_func_2src_1dst(ref_mul_conj_cplx32_re_im, mul_conj_cplx32_re_im, (int32_t*)src0, (int32_t*)src1,
                                 (int32_t*)dst_ref, (int32_t*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| mul_conj_cplx_fl_re_im |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(float) * 2;
    failed_count +=
        test_cplx_func_2src_1dst(ref_mul_conj_cplx_fl_re_im, mul_conj_cplx_fl_re_im, (float*)src0, (float*)src1,
                                 (float*)dst_ref, (float*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| mul_conj_cplx_db_re_im |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(double) * 2;
    failed_count +=
        test_cplx_func_2src_1dst(ref_mul_conj_cplx_db_re_im, mul_conj_cplx_db_re_im, (double*)src0, (double*)src1,
                                 (double*)dst_ref, (double*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| mul_conj_cplx32        |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 2;
    failed_count += test_cplx_func_2src_1dst(ref_mul_conj_cplx32, mul_conj_cplx32, (int32_t*)src0, (int32_t*)src1,
                                             (int32_t*)dst_ref, (int32_t*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  // ---- In-place functions taking a scalar constant ----

  func_name = "| add_rconst_cplx32      |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 1;

    int32_t value = rand() % 100;

    failed_count += test_cplx_func_rconst(ref_add_rconst_cplx32, add_rconst_cplx32, value, (int32_t*)dst_ref,
                                          (int32_t*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| mul_rconst_cplx32      |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 1;

    int32_t value = rand() % 10;

    failed_count += test_cplx_func_rconst(ref_mul_rconst_cplx32, mul_rconst_cplx32, value, (int32_t*)dst_ref,
                                          (int32_t*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| scale_cplx32           |";
  // `scale` must outlive a single iteration: even iterations draw a random value in
  // [-5, 4], odd iterations reuse the previous value with the opposite sign.
  // (Declaring it inside the loop made every odd iteration test scale == 0.)
  int32_t scale = 0;
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 1;
    if (i % 2)
      scale = -scale;
    else
      scale = rand() % 10 - 5;

    failed_count += test_cplx_func_rconst(ref_scale_cplx32, scale_cplx32, scale, (int32_t*)dst_ref, (int32_t*)dst_opt,
                                          size[i], input_bytes, func_name);
  }
  print_separator_line();

  func_name = "| div_rconst_cplx32      |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 1;

    // Divisor in [-5, 4]; a zero draw is replaced by a value in [1, 10].
    int32_t coef = rand() % 10 - 5;
    if (coef == 0) coef = rand() % 10 + 1;

    failed_count += test_cplx_func_rconst(ref_div_rconst_cplx32, div_rconst_cplx32, coef, (int32_t*)dst_ref,
                                          (int32_t*)dst_opt, size[i], input_bytes, func_name);
  }
  print_separator_line();

  // ---- Accumulate ----

  create_array(static_cast<int32_t*>(src0), 2 * max_size, 0);
  create_array(static_cast<int32_t*>(src1), 2 * max_size, 0);

  // Destinations are cleared once; both are then updated by the same sequence of calls.
  memset(dst_ref, 0, sizeof(int32_t) * max_size * 2);
  memset(dst_opt, 0, sizeof(int32_t) * max_size * 2);
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 2;
    std::cout << "| vec_acc_cplx32         |";
    std::cout << " " << std::setw(14) << size[i] << " |";

    failed_count +=
        test_vec_acc_cplx32((int32_t*)src0, (int32_t*)dst_opt, (int32_t*)dst_ref, size[i], print, input_bytes);
  }

  for (int i = 0; i < max_size * 2; ++i) {
    ((int32_t*)src0)[i] = rand() % 100 - 50;
    ((int32_t*)src1)[i] = rand() % 100 - 50;
  }
  print_separator_line();

  // ---- Sum ----

  // Both implementations read identical small random inputs.
  for (int i = 0; i < max_size * 2; ++i) {
    ((int32_t*)dst_ref)[i] = rand() % 10 - 5;
    ((int32_t*)dst_opt)[i] = ((int32_t*)dst_ref)[i];
  }

  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 1;
    std::cout << "| vec_sum_cplx32         |";
    std::cout << " " << std::setw(14) << size[i] << " |";

    failed_count += test_vec_sum_cplx32((int32_t*)dst_opt, (int32_t*)dst_ref, size[i], print, input_bytes);
  }
  print_separator_line();

  // ---- Absolute value ----

  for (int i = 0; i < max_size * 2; ++i) {
    ((int32_t*)src0)[i] = rand() % 100 - 50;
  }

  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 1;
    std::cout << "| abs_cplx32             |";
    std::cout << " " << std::setw(14) << size[i] << " |";

    failed_count += test_abs_cplx32((int32_t*)src0, (int32_t*)dst_opt, (int32_t*)dst_ref, size[i], print, input_bytes);
  }
  print_separator_line();

  // ---- Convolution, int32 ----

  memset(dst_ref, 0, sizeof(int32_t) * max_size * 2);
  memset(dst_opt, 0, sizeof(int32_t) * max_size * 2);

  for (int i = 0; i < max_size * 2; ++i) {
    ((int32_t*)src0)[i] = rand() % 4 - 2;
  }

  for (int i = 0; i < max_size * 2; ++i) {
    ((int32_t*)src1)[i] = rand() % 4 - 2;
  }

  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 2;
    // Random kernel length in [1, size[i]].
    int size_ker = (rand() % size[i]) + 1;
    std::cout << "| conv_cplx, ker = " << std::setw(5) << size_ker << " |";
    std::cout << " " << std::setw(14) << size[i] << " |";

    failed_count += test_conv_cplx(ref_conv_cplx32, conv_cplx32, (int32_t*)src0, (int32_t*)src1, (int32_t*)dst_opt,
                                   (int32_t*)dst_ref, size[i], size_ker, print, input_bytes);
  }
  print_separator_line();

  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(int32_t) * 2;
    int size_ker = (rand() % size[i]) + 1;
    std::cout << "| conv32_re_im, kr=" << std::setw(5) << size_ker << " |";
    std::cout << " " << std::setw(14) << size[i] << " |";

    failed_count += test_conv_cplx(ref_conv_cplx32_re_im, conv_cplx32_re_im, (int32_t*)src0, (int32_t*)src1,
                                   (int32_t*)dst_opt, (int32_t*)dst_ref, size[i], size_ker, print, input_bytes);
  }
  print_separator_line();

  // ---- Convolution, float ----

  memset(dst_ref, 0, sizeof(float) * max_size * 2);
  memset(dst_opt, 0, sizeof(float) * max_size * 2);

  for (int i = 0; i < max_size * 2; ++i) {
    ((float*)src0)[i] = (float)(rand() % 3 - 1);
  }

  for (int i = 0; i < max_size * 2; ++i) {
    ((float*)src1)[i] = (float)(rand() % 3 - 1);
  }

  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * 2 * sizeof(float) * 2;
    int size_ker = (rand() % size[i]) + 1;
    std::cout << "| conv_fl_re_im, k=" << std::setw(5) << size_ker << " |";
    std::cout << " " << std::setw(14) << size[i] << " |";

    failed_count += test_conv_cplx(ref_conv_cplx_fl_re_im, conv_cplx_fl_re_im, (float*)src0, (float*)src1,
                                   (float*)dst_opt, (float*)dst_ref, size[i], size_ker, print, input_bytes);
  }
  print_separator_line();

#ifndef LOCAL_MEM
  free(src0);
  free(src1);
  free(dst_opt);
  free(dst_ref);
#else
#ifndef BARE_METAL
  enable_l2_cache(L2_CACHE_SIZE);
#endif
#endif

  return failed_count;
}

int test_vec_acc_cplx32(int32_t* src0, int32_t* dst_opt, int32_t* dst_ref, int size, int print, int32_t input_bytes) {
  int ret = 0;

  uint32_t tic_count[2], instruction_count[2];
  uint32_t ref_tic_count[2], ref_instruction_count[2];

  count_tics(ref_tic_count, ref_instruction_count);
  ref_acc_cplx32(dst_ref, src0, size);
  count_tics(&ref_tic_count[1], &ref_instruction_count[1]);

  count_tics(tic_count, instruction_count);
  acc_cplx32(dst_opt, src0, size);
  count_tics(&tic_count[1], &instruction_count[1]);

  ret = compare_arrays(dst_ref, dst_opt, size * 2);

  print_performance(ref_tic_count, ref_instruction_count, tic_count, instruction_count, input_bytes, 0);

  if (ret == 0)
    std::cout << " passed |\n";
  else
    std::cout << " failed |\n";

  return ret;
}

int test_vec_sum_cplx32(int32_t* dst_opt, int32_t* dst_ref, int size, int print, int32_t input_bytes) {
  int ret = 0;

  uint32_t tic_count[2], instruction_count[2];
  uint32_t ref_tic_count[2], ref_instruction_count[2];

  int32_t ref_res[2];
  count_tics(ref_tic_count, ref_instruction_count);
  ref_sum_cplx32(dst_ref, size, ref_res);
  count_tics(&ref_tic_count[1], &ref_instruction_count[1]);

  int32_t opt_res[2];
  count_tics(tic_count, instruction_count);
  sum_cplx32(dst_opt, size, opt_res);
  count_tics(&tic_count[1], &instruction_count[1]);

  ret = compare_arrays(ref_res, opt_res, 2);

  print_performance(ref_tic_count, ref_instruction_count, tic_count, instruction_count, input_bytes, 0);

  if (ret == 0)
    std::cout << " passed |\n";
  else
    std::cout << " failed |\n";

  return ret;
}

int test_abs_cplx32(int32_t* src0, int32_t* dst_opt, int32_t* dst_ref, int size, int print, int32_t input_bytes) {
  int ret = 0;

  uint32_t tic_count[2], instruction_count[2];
  uint32_t ref_tic_count[2], ref_instruction_count[2];

  count_tics(ref_tic_count, ref_instruction_count);
  ref_abs_cplx32(src0, size, dst_ref);
  count_tics(&ref_tic_count[1], &ref_instruction_count[1]);

  count_tics(tic_count, instruction_count);
  abs_cplx32(src0, size, dst_opt);
  count_tics(&tic_count[1], &instruction_count[1]);

  ret = compare_arrays_with_eps(dst_ref, dst_opt, size, 0.01);

  print_performance(ref_tic_count, ref_instruction_count, tic_count, instruction_count, input_bytes, 0);

  if (ret == 0)
    std::cout << " passed |\n";
  else
    std::cout << " failed |\n";

  return ret;
}

// Times a reference and an optimized convolution on the same inputs, compares
// size * 2 output values and prints the remaining cells of the current table row:
// tic counts (and instruction counts when PRINT_INSTR is defined), a mul/tic
// figure for the optimized call, and the passed/failed verdict.
//
// ref_func / opt_func - callables invoked as f(src0, src1, size, size_ker, dst)
// src0, src1          - inputs handed to both callables
// dst_opt / dst_ref   - outputs of opt_func / ref_func
// size, size_ker      - lengths forwarded to both callables
// print, input_bytes  - not used by this function
//
// Returns the value produced by compare_arrays (0 is reported as "passed").
// Note: leaves std::cout's precision set to 3.
template <class T, class ref_func_ptr, class opt_func_ptr>
int test_conv_cplx(ref_func_ptr ref_func, opt_func_ptr opt_func, T* src0, T* src1, T* dst_opt, T* dst_ref, int size,
                   int size_ker, int print, int32_t input_bytes) {
  // Slot 0 is sampled before the measured call, slot 1 right after it.
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t opt_tics[2], opt_instrs[2];

  // Reference implementation.
  count_tics(ref_tics, ref_instrs);
  ref_func(src0, src1, size, size_ker, dst_ref);
  count_tics(ref_tics + 1, ref_instrs + 1);

  // Optimized implementation.
  count_tics(opt_tics, opt_instrs);
  opt_func(src0, src1, size, size_ker, dst_opt);
  count_tics(opt_tics + 1, opt_instrs + 1);

  const int status = compare_arrays(dst_ref, dst_opt, size * 2);

  // Reference columns.
  std::cout << std::setw(11) << ref_tics[1] - ref_tics[0] << " |";
#ifdef PRINT_INSTR
  std::cout << std::setw(11) << ref_instrs[1] - ref_instrs[0] << " |";
#endif

  // Optimized columns.
  std::cout << std::setw(11) << opt_tics[1] - opt_tics[0] << " |";
#ifdef PRINT_INSTR
  std::cout << std::setw(11) << opt_instrs[1] - opt_instrs[0] << " |";
#endif

  // size_ker * (size - size_ker + 1) products divided by the optimized tic count.
  std::cout.precision(3);
  std::cout << "          - | " << std::setw(3) << 1.0 * size_ker * (size - size_ker + 1) / (opt_tics[1] - opt_tics[0])
            << " mul/tic |";

  std::cout << "            - |" << (status == 0 ? " passed |\n" : " failed |\n");

  return status;
}

// Prints one table row for a function taking two source vectors and one
// destination: fills both sources via create_array, clears both destinations,
// times the reference and optimized callables, compares size * 2 output values
// and prints the performance columns plus the passed/failed verdict.
//
// ref_func / opt_func - callables invoked as f(src0, src1, size, dst)
// src0, src1          - input buffers, overwritten by create_array (2 * size values each)
// dst_ref / dst_opt   - outputs of ref_func / opt_func
// size                - length forwarded to both callables
// input_bytes         - forwarded to print_performance
// message             - text printed at the start of the row
//
// Returns the value produced by compare_arrays (0 is reported as "passed").
template <class T, class ref_func_ptr, class opt_func_ptr>
int test_cplx_func_2src_1dst(ref_func_ptr ref_func, opt_func_ptr opt_func, T* src0, T* src1, T* dst_ref, T* dst_opt,
                             int size, int32_t input_bytes, std::string message) {
  std::cout << message << " " << std::setw(14) << size << " |";

  // Slot 0 is sampled before the measured call, slot 1 right after it.
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t opt_tics[2], opt_instrs[2];

  // Prepare inputs and zeroed outputs.
  create_array(src0, 2 * size, 0);
  create_array(src1, 2 * size, 0);
  memset(dst_ref, 0, sizeof(T) * size * 2);
  memset(dst_opt, 0, sizeof(T) * size * 2);

  // Reference implementation.
  count_tics(ref_tics, ref_instrs);
  ref_func(src0, src1, size, dst_ref);
  count_tics(ref_tics + 1, ref_instrs + 1);

  // Optimized implementation.
  count_tics(opt_tics, opt_instrs);
  opt_func(src0, src1, size, dst_opt);
  count_tics(opt_tics + 1, opt_instrs + 1);

  const int status = compare_arrays(dst_ref, dst_opt, size * 2);

  print_performance(ref_tics, ref_instrs, opt_tics, opt_instrs, input_bytes, 0);

  std::cout << (status == 0 ? " passed |\n" : " failed |\n");

  return status;
}

// Prints one table row for an in-place function taking a scalar: fills dst_ref
// via create_array, copies it into dst_opt so both callables start from the same
// data, times the reference and optimized callables, compares size * 2 values
// and prints the performance columns plus the passed/failed verdict.
//
// ref_func / opt_func - callables invoked as f(buffer, size, value)
// value               - scalar forwarded to both callables
// dst_ref / dst_opt   - buffers processed by ref_func / opt_func
// size                - length forwarded to both callables
// input_bytes         - forwarded to print_performance
// message             - text printed at the start of the row
//
// Returns the value produced by compare_arrays (0 is reported as "passed").
template <class T, class ref_func_ptr, class opt_func_ptr>
int test_cplx_func_rconst(ref_func_ptr ref_func, opt_func_ptr opt_func, T value, T* dst_ref, T* dst_opt, int size,
                          int32_t input_bytes, std::string message) {
  std::cout << message << " " << std::setw(14) << size << " |";

  // Slot 0 is sampled before the measured call, slot 1 right after it.
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t opt_tics[2], opt_instrs[2];

  // Identical starting contents for both buffers.
  create_array(dst_ref, 2 * size, 0);
  std::memcpy(dst_opt, dst_ref, 2 * size * sizeof(T));

  // Reference implementation.
  count_tics(ref_tics, ref_instrs);
  ref_func(dst_ref, size, value);
  count_tics(ref_tics + 1, ref_instrs + 1);

  // Optimized implementation.
  count_tics(opt_tics, opt_instrs);
  opt_func(dst_opt, size, value);
  count_tics(opt_tics + 1, opt_instrs + 1);

  const int status = compare_arrays(dst_ref, dst_opt, size * 2);

  print_performance(ref_tics, ref_instrs, opt_tics, opt_instrs, input_bytes, 0);

  std::cout << (status == 0 ? " passed |\n" : " failed |\n");

  return status;
}
