// Copyright 2025 RnD Center "ELVEES", JSC

#include "cpp_tests.hpp"

// Test driver for the multiplication kernels.
//
// Runs test_mul() for four kernel pairs (ref_muls16/muls16, ref_muls32/muls32,
// ref_mul_fl/mul_fl, ref_mul_db/mul_db) over every entry of size[], printing
// one table row per call and a separator line after each kernel group.
//
// Buffers:
//   - Without LOCAL_MEM, four 64-byte aligned heap buffers are obtained via
//     memalign() and released with free() before returning.
//   - With LOCAL_MEM, the four buffers are laid out back to back starting at
//     __local_mem (BARE_METAL) or xyram_data (otherwise). In the non-BARE_METAL
//     case disable_l2_cache() is called before the run and
//     enable_l2_cache(L2_CACHE_SIZE) after it.
//
// Returns the sum of all test_mul() return values (0 when every call returned
// 0), or 1 if a heap buffer could not be allocated.
int main() {
  int failed_count = 0;
  int size[TEST_COUNT] = {64, 512, 1060, 2126, 8192};

  // Each buffer holds size[TEST_COUNT - 1] elements of int64_t and is reused
  // for every element type tested below.
  const size_t buffer_bytes = size[TEST_COUNT - 1] * sizeof(int64_t);

  // Separator line printed after each kernel group.
  static const char separator[] =
      "-------------------------------------------------------------------"
      "------------------------------------------------------";

  print_table_header();

#ifndef LOCAL_MEM
  void* src0 = memalign(64, buffer_bytes);
  void* src1 = memalign(64, buffer_bytes);
  void* dst_opt = memalign(64, buffer_bytes);
  void* dst_ref = memalign(64, buffer_bytes);

  // memalign() reports failure with a null pointer; without this check the
  // tests below would write through it.
  if (!src0 || !src1 || !dst_opt || !dst_ref) {
    std::cerr << "Failed to allocate test buffers" << std::endl;
    // free() accepts null pointers, so release whichever allocations succeeded.
    free(src0);
    free(src1);
    free(dst_opt);
    free(dst_ref);
    return 1;
  }
#else
#ifdef BARE_METAL
  void* src0 = &__local_mem;
#else
  disable_l2_cache();
  void* src0 = &xyram_data;
#endif
  // The remaining buffers follow src0 consecutively, buffer_bytes apart.
  void* src1 = static_cast<int8_t*>(src0) + buffer_bytes;
  void* dst_opt = static_cast<int8_t*>(src1) + buffer_bytes;
  void* dst_ref = static_cast<int8_t*>(dst_opt) + buffer_bytes;
#endif

  // 16-bit integer kernels.
  std::string func_name = "| mul_16                 |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    // Total bytes in both input arrays, forwarded to test_mul().
    int32_t input_bytes = size[i] * sizeof(int16_t) * 2;
    failed_count += test_mul(ref_muls16, muls16, static_cast<int16_t*>(src0), static_cast<int16_t*>(src1),
                             static_cast<int16_t*>(dst_ref), static_cast<int16_t*>(dst_opt), size[i], input_bytes,
                             func_name);
  }
  std::cout << separator << std::endl;

  // 32-bit integer kernels.
  func_name = "| mul_32                 |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * sizeof(int32_t) * 2;
    failed_count += test_mul(ref_muls32, muls32, static_cast<int32_t*>(src0), static_cast<int32_t*>(src1),
                             static_cast<int32_t*>(dst_ref), static_cast<int32_t*>(dst_opt), size[i], input_bytes,
                             func_name);
  }
  std::cout << separator << std::endl;

  // Single-precision floating-point kernels.
  func_name = "| mul_fl                 |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * sizeof(float) * 2;
    failed_count += test_mul(ref_mul_fl, mul_fl, static_cast<float*>(src0), static_cast<float*>(src1),
                             static_cast<float*>(dst_ref), static_cast<float*>(dst_opt), size[i], input_bytes,
                             func_name);
  }
  std::cout << separator << std::endl;

  // Double-precision floating-point kernels.
  func_name = "| mul_db                 |";
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * sizeof(double) * 2;
    failed_count += test_mul(ref_mul_db, mul_db, static_cast<double*>(src0), static_cast<double*>(src1),
                             static_cast<double*>(dst_ref), static_cast<double*>(dst_opt), size[i], input_bytes,
                             func_name);
  }
  std::cout << separator << std::endl;

#ifndef LOCAL_MEM
  free(src0);
  free(src1);
  free(dst_opt);
  free(dst_ref);
#else
#ifndef BARE_METAL
  enable_l2_cache(L2_CACHE_SIZE);
#endif
#endif

  return failed_count;
}

// Runs one reference/optimized kernel pair on the same inputs and prints one
// table row for it.
//
// Steps, in order:
//   1. print `message` followed by `size` right-aligned in a 14-character field;
//   2. fill src0 and src1 through create_array() and zero both outputs;
//   3. call ref_func and then opt_func, each bracketed by two count_tics()
//      calls that store into slots [0] and [1] of the counter arrays;
//   4. compare dst_ref with dst_opt through compare_arrays();
//   5. hand the counters and input_bytes to print_performance() and append
//      " passed |" or " failed |".
//
// Parameters:
//   ref_func, opt_func - callables invoked as f(src0, src1, dst, size)
//   src0, src1         - input buffers, at least `size` elements each
//   dst_ref, dst_opt   - output buffers for ref_func / opt_func respectively
//   size               - number of elements to process
//   input_bytes        - forwarded unchanged to print_performance()
//   message            - row label printed first
//
// Returns the value produced by compare_arrays(); 0 is reported as "passed".
template <class T, class ref_func_ptr, class opt_func_ptr>
int test_mul(ref_func_ptr ref_func, opt_func_ptr opt_func, T* src0, T* src1, T* dst_ref, T* dst_opt, int size,
             int32_t input_bytes, std::string message) {
  std::cout << message << " " << std::setw(14) << size << " |";

  // Slot [0] is written before the kernel call, slot [1] after it.
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t opt_tics[2], opt_instrs[2];

  // Inputs are generated before the outputs are cleared, src0 first.
  create_array(src0, size, 0);
  create_array(src1, size, 0);

  memset(dst_ref, 0, sizeof(T) * size);
  memset(dst_opt, 0, sizeof(T) * size);

  // Reference kernel.
  count_tics(ref_tics, ref_instrs);
  ref_func(src0, src1, dst_ref, size);
  count_tics(ref_tics + 1, ref_instrs + 1);

  // Optimized kernel.
  count_tics(opt_tics, opt_instrs);
  opt_func(src0, src1, dst_opt, size);
  count_tics(opt_tics + 1, opt_instrs + 1);

  const int mismatch = compare_arrays(dst_ref, dst_opt, size);

  print_performance(ref_tics, ref_instrs, opt_tics, opt_instrs, input_bytes, 0);

  std::cout << (mismatch == 0 ? " passed |\n" : " failed |\n");

  return mismatch;
}
