// Copyright 2025 RnD Center "ELVEES", JSC

#include "tests_tile_segmentation.hpp"

template <typename Type, class create_func, class ref_func, class run_calc_ptr, class cmp_func>
bool test_add(Type* src0, Type* src1, Type* dst_ref, Type* dst_opt, create_func create_vector, ref_func reference,
              run_calc_ptr run_calc, cmp_func comparator, int size, int* localmem) {
  create_vector(src0, size, 0);
  create_vector(src1, size, 0);

  FLUSH_ALL_CACHES();
  uint32_t tic_count[2], instruction_count[2];
  count_tics(tic_count, instruction_count);
  reference(src0, src1, dst_ref, size);
  count_tics(&tic_count[1], &instruction_count[1]);

  std::cout << "Ref func result (size = " << size << "): tic = " << tic_count[1] - tic_count[0]
            << " instr = " << instruction_count[1] - instruction_count[0] << std::endl;

  TileSegConfig config;
  CreateTileSegConfigAdd(src0, src1, dst_opt, size, &config, (uint16_t*)localmem);

  FLUSH_ALL_CACHES();

  count_tics(tic_count, instruction_count);
  run_calc(&config);
  count_tics(&tic_count[1], &instruction_count[1]);

  std::cout << "Opt func result (size = " << size << "): tic = " << tic_count[1] - tic_count[0]
            << " instr = " << instruction_count[1] - instruction_count[0] << std::endl;

  int ret = comparator(dst_ref, dst_opt, size);
  return ret;
}

/**
 * Test driver for the tile-segmented addition kernels.
 *
 * For every power-of-two size from 1 up to SIZE it runs the int16, int32, float and double
 * variants through test_add and accumulates their results into a single status.
 * The L2 cache is disabled for the duration of the tests and re-enabled before returning.
 *
 * @return 0 when every comparison succeeded; non-zero if any comparison failed or if the
 *         working buffers could not be allocated.
 */
int main() {
  disable_l2_cache();

  // All four buffers are shared between element types, so they are sized as SIZE int64_t
  // elements; the narrower types simply use a prefix of each buffer.
  void* src0 = memalign(64, SIZE * sizeof(int64_t));
  void* src1 = memalign(64, SIZE * sizeof(int64_t));
  void* dst_ref = memalign(64, SIZE * sizeof(int64_t));
  void* dst_opt = memalign(64, SIZE * sizeof(int64_t));

  // memalign reports failure by returning a null pointer. Bail out before any kernel writes
  // through it, releasing whatever was obtained (free accepts null) and restoring the cache.
  if (!src0 || !src1 || !dst_ref || !dst_opt) {
    std::cout << "Buffer allocation failed!" << std::endl;
    free(src0);
    free(src1);
    free(dst_ref);
    free(dst_opt);
    enable_l2_cache(L2_CACHE_SIZE);
    return 1;
  }

  int ret = 0;
  int test_status = 0;

  // With USE_REF_VER the ref_* functions serve as the reference implementation; otherwise the
  // non-prefixed variants (adds16, adds32, add_fl, add_db) are used.
  for (int i = 1; i <= SIZE; i *= 2) {
    std::cout << "Add_s16" << std::endl;
#ifdef USE_REF_VER
    ret = test_add(static_cast<int16_t*>(src0), static_cast<int16_t*>(src1), static_cast<int16_t*>(dst_ref),
                   static_cast<int16_t*>(dst_opt), create_vector_s16, ref_adds16, RunCalculationAdd16, compare_s16, i,
                   &__local_mem);
#else
    ret = test_add(static_cast<int16_t*>(src0), static_cast<int16_t*>(src1), static_cast<int16_t*>(dst_ref),
                   static_cast<int16_t*>(dst_opt), create_vector_s16, adds16, RunCalculationAdd16, compare_s16, i,
                   &__local_mem);
#endif

    test_status |= ret;
    if (ret) std::cout << "add16 error!\n";

    std::cout << "Add_s32" << std::endl;
#ifdef USE_REF_VER
    ret = test_add(static_cast<int32_t*>(src0), static_cast<int32_t*>(src1), static_cast<int32_t*>(dst_ref),
                   static_cast<int32_t*>(dst_opt), create_vector_s32, ref_adds32, RunCalculationAdd32, compare_s32, i,
                   &__local_mem);
#else
    ret = test_add(static_cast<int32_t*>(src0), static_cast<int32_t*>(src1), static_cast<int32_t*>(dst_ref),
                   static_cast<int32_t*>(dst_opt), create_vector_s32, adds32, RunCalculationAdd32, compare_s32, i,
                   &__local_mem);
#endif

    test_status |= ret;
    if (ret) std::cout << "add32 error!\n";

    std::cout << "Add_fl" << std::endl;
#ifdef USE_REF_VER
    ret = test_add(static_cast<float*>(src0), static_cast<float*>(src1), static_cast<float*>(dst_ref),
                   static_cast<float*>(dst_opt), create_vector_float, ref_add_fl, RunCalculationAddFl, compare_float,
                   i, &__local_mem);
#else
    ret = test_add(static_cast<float*>(src0), static_cast<float*>(src1), static_cast<float*>(dst_ref),
                   static_cast<float*>(dst_opt), create_vector_float, add_fl, RunCalculationAddFl, compare_float, i,
                   &__local_mem);
#endif

    test_status |= ret;
    if (ret) std::cout << "add_fl error!\n";

    std::cout << "Add_db" << std::endl;
#ifdef USE_REF_VER
    ret = test_add(static_cast<double*>(src0), static_cast<double*>(src1), static_cast<double*>(dst_ref),
                   static_cast<double*>(dst_opt), create_vector_double, ref_add_db, RunCalculationAddDb,
                   compare_double, i, &__local_mem);
#else
    ret = test_add(static_cast<double*>(src0), static_cast<double*>(src1), static_cast<double*>(dst_ref),
                   static_cast<double*>(dst_opt), create_vector_double, add_db, RunCalculationAddDb, compare_double, i,
                   &__local_mem);
#endif

    test_status |= ret;
    if (ret) std::cout << "add_db error!\n";

    // The verdict is cumulative: once any size has failed, later iterations keep reporting
    // "Test failed" even if their own comparisons succeed.
    if (!test_status)
      std::cout << "Test passed" << std::endl;
    else
      std::cout << "Test failed" << std::endl;
  }

  free(src0);
  free(src1);
  free(dst_ref);
  free(dst_opt);

  enable_l2_cache(L2_CACHE_SIZE);

  return test_status;
}
