// Copyright 2018-2025 RnD Center "ELVEES", JSC

/*! \file
 *  \brief Тестирование функции dotp_sqr
 *  \author Фролов Андрей
 */
#include "tests.h"

/*! \brief Entry point of the dotp_sqr test program.
 *
 *  Fills three pairs of input buffers (neg, pos, rnd) with 16-bit, 32-bit,
 *  float and double data in turn and, for every entry of size[], runs the
 *  matching test function (test_dotp_sqr, test_dotp_sqr32, test_dotp_sqr_fl,
 *  test_dotp_sqr_db) on each pair, printing one table row per run.
 *
 *  \return Sum of the values returned by the test functions (0 when every run
 *          passed), or 1 when the heap buffers could not be allocated.
 */
int main() {
  int failed_count = 0;
  int size[TEST_COUNT] = {256, 512, 1024, 2048, 8192};
  int32_t G[TEST_COUNT] = {0, 50, 100, INT32_MAX, INT32_MIN};

  // Every buffer holds size[TEST_COUNT - 1] elements of the widest element
  // type used below (8 bytes), so the same memory is reused for all four
  // element types.
  const int vec_len = size[TEST_COUNT - 1];
  const size_t buf_bytes = (size_t)vec_len * sizeof(int64_t);

  print_table_header();

#ifndef LOCAL_MEM
  void* src0_pos = memalign(64, buf_bytes);
  void* src1_pos = memalign(64, buf_bytes);
  void* src0_neg = memalign(64, buf_bytes);
  void* src1_neg = memalign(64, buf_bytes);
  void* src0_rnd = memalign(64, buf_bytes);
  void* src1_rnd = memalign(64, buf_bytes);

  // memalign returns a null pointer on failure; stop before any buffer is
  // written instead of dereferencing it. free() accepts null pointers, so the
  // buffers that were allocated are released unconditionally.
  if (!src0_pos || !src1_pos || !src0_neg || !src1_neg || !src0_rnd || !src1_rnd) {
    printf("failed to allocate test buffers\n");
    free(src0_pos);
    free(src1_pos);
    free(src0_neg);
    free(src1_neg);
    free(src0_rnd);
    free(src1_rnd);
    return 1;
  }
#else
#ifdef BARE_METAL
  void* src0_pos = &__local_mem;
#else
  disable_l2_cache();
  void* src0_pos = &xyram_data;
#endif

  // The six buffers are laid out back to back starting at src0_pos. The
  // offsets are computed through char* because arithmetic on void* is a GNU
  // extension rather than ISO C; the resulting addresses are the same.
  // NOTE(review): assumes the region behind src0_pos is at least
  // 6 * buf_bytes long - confirm against the linker script / xyram_data.
  void* src1_pos = (char*)src0_pos + buf_bytes;
  void* src0_neg = (char*)src1_pos + buf_bytes;
  void* src1_neg = (char*)src0_neg + buf_bytes;
  void* src0_rnd = (char*)src1_neg + buf_bytes;
  void* src1_rnd = (char*)src0_rnd + buf_bytes;
#endif

  // Passed to every test function; a non-zero value makes them dump their
  // inputs and results.
  int print = 0;

  // The last argument of create_vector_* is -1 / 1 / 0 for the neg / pos / rnd
  // buffers; presumably it selects negative, positive or random contents -
  // confirm against the create_vector_* helpers.

  // 16-bit integer inputs.
  create_vector_s16((int16_t*)src0_neg, vec_len, -1);
  create_vector_s16((int16_t*)src1_neg, vec_len, -1);
  create_vector_s16((int16_t*)src0_pos, vec_len, 1);
  create_vector_s16((int16_t*)src1_pos, vec_len, 1);
  create_vector_s16((int16_t*)src0_rnd, vec_len, 0);
  create_vector_s16((int16_t*)src1_rnd, vec_len, 0);
  for (int i = 0; i < TEST_COUNT; ++i) {
    // Two input vectors of size[i] elements each.
    int32_t input_bytes = size[i] * sizeof(int16_t) * 2;
    // Forwarded to print_performance; presumably a reference cycle count for
    // this vector length - confirm its origin.
    int32_t ti_tics = size[i] / 2 + 31;

#ifndef DISABLE_NEG_TEST_DATA
    printf("| dotp_sqr_short_neg     | %14d |", size[i]);
    failed_count += test_dotp_sqr((int16_t*)src0_neg, (int16_t*)src1_neg, size[i], G[i], print, input_bytes, ti_tics);
#endif

    printf("| dotp_sqr_short_rnd     | %14d |", size[i]);
    failed_count += test_dotp_sqr((int16_t*)src0_rnd, (int16_t*)src1_rnd, size[i], G[i], print, input_bytes, ti_tics);

#ifndef DISABLE_POS_TEST_DATA
    printf("| dotp_sqr_short_pos     | %14d |", size[i]);
    failed_count += test_dotp_sqr((int16_t*)src0_pos, (int16_t*)src1_pos, size[i], G[i], print, input_bytes, ti_tics);
#endif
  }

  // 32-bit integer inputs.
  create_vector_s32((int32_t*)src0_neg, vec_len, -1);
  create_vector_s32((int32_t*)src1_neg, vec_len, -1);
  create_vector_s32((int32_t*)src0_pos, vec_len, 1);
  create_vector_s32((int32_t*)src1_pos, vec_len, 1);
  create_vector_s32((int32_t*)src0_rnd, vec_len, 0);
  create_vector_s32((int32_t*)src1_rnd, vec_len, 0);
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * sizeof(int32_t) * 2;
    int32_t ti_tics = 0;

#ifndef DISABLE_NEG_TEST_DATA
    printf("| dotp_sqr_int_neg       | %14d |", size[i]);
    failed_count +=
        test_dotp_sqr32((int32_t*)src0_neg, (int32_t*)src1_neg, size[i], (int64_t)G[i], print, input_bytes, ti_tics);
#endif

    printf("| dotp_sqr_int_rnd       | %14d |", size[i]);
    failed_count +=
        test_dotp_sqr32((int32_t*)src0_rnd, (int32_t*)src1_rnd, size[i], (int64_t)G[i], print, input_bytes, ti_tics);

#ifndef DISABLE_POS_TEST_DATA
    printf("| dotp_sqr_int_pos       | %14d |", size[i]);
    failed_count +=
        test_dotp_sqr32((int32_t*)src0_pos, (int32_t*)src1_pos, size[i], (int64_t)G[i], print, input_bytes, ti_tics);
#endif
  }

  // Single-precision inputs.
  create_vector_float((float*)src0_neg, vec_len, -1);
  create_vector_float((float*)src1_neg, vec_len, -1);
  create_vector_float((float*)src0_pos, vec_len, 1);
  create_vector_float((float*)src1_pos, vec_len, 1);
  create_vector_float((float*)src0_rnd, vec_len, 0);
  create_vector_float((float*)src1_rnd, vec_len, 0);
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * sizeof(float) * 2;
    int32_t ti_tics = 0;

#ifndef DISABLE_NEG_TEST_DATA
    printf("| dotp_sqr_fl_neg        | %14d |", size[i]);
    failed_count +=
        test_dotp_sqr_fl((float*)src0_neg, (float*)src1_neg, size[i], (float)G[i], print, input_bytes, ti_tics);
#endif

    printf("| dotp_sqr_fl_rnd        | %14d |", size[i]);
    failed_count +=
        test_dotp_sqr_fl((float*)src0_rnd, (float*)src1_rnd, size[i], (float)G[i], print, input_bytes, ti_tics);

#ifndef DISABLE_POS_TEST_DATA
    printf("| dotp_sqr_fl_pos        | %14d |", size[i]);
    failed_count +=
        test_dotp_sqr_fl((float*)src0_pos, (float*)src1_pos, size[i], (float)G[i], print, input_bytes, ti_tics);
#endif
  }

  // Double-precision inputs.
  create_vector_double((double*)src0_neg, vec_len, -1);
  create_vector_double((double*)src1_neg, vec_len, -1);
  create_vector_double((double*)src0_pos, vec_len, 1);
  create_vector_double((double*)src1_pos, vec_len, 1);
  create_vector_double((double*)src0_rnd, vec_len, 0);
  create_vector_double((double*)src1_rnd, vec_len, 0);
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = size[i] * sizeof(double) * 2;
    int32_t ti_tics = 0;

#ifndef DISABLE_NEG_TEST_DATA
    printf("| dotp_sqr_db_neg        | %14d |", size[i]);
    failed_count +=
        test_dotp_sqr_db((double*)src0_neg, (double*)src1_neg, size[i], (double)G[i], print, input_bytes, ti_tics);
#endif

    printf("| dotp_sqr_db_rnd        | %14d |", size[i]);
    failed_count +=
        test_dotp_sqr_db((double*)src0_rnd, (double*)src1_rnd, size[i], (double)G[i], print, input_bytes, ti_tics);

#ifndef DISABLE_POS_TEST_DATA
    printf("| dotp_sqr_db_pos        | %14d |", size[i]);
    failed_count +=
        test_dotp_sqr_db((double*)src0_pos, (double*)src1_pos, size[i], (double)G[i], print, input_bytes, ti_tics);
#endif
  }

#ifndef LOCAL_MEM
  free(src0_pos);
  free(src1_pos);
  free(src0_neg);
  free(src1_neg);
  free(src0_rnd);
  free(src1_rnd);
#else
#ifndef BARE_METAL
  // Undo the disable_l2_cache() call made before the tests.
  enable_l2_cache(L2_CACHE_SIZE);
#endif
#endif

  return failed_count;
}

/*! \brief Compares dotp_sqr with ref_dotp_sqr on 16-bit input and finishes the table row.
 *
 *  Both functions receive the same G, src0, src1 and size. The run passes when
 *  their return values are equal and the values left in their output argument
 *  are equal.
 *
 *  \param src0         first input vector
 *  \param src1         second input vector
 *  \param size         element count handed to both functions
 *  \param G            value passed as the first argument and used to seed the output argument
 *  \param print        non-zero to dump G, both inputs and both result pairs
 *  \param input_bytes  forwarded to print_performance
 *  \param ti_tics      forwarded to print_performance
 *  \return 0 when the results match, 1 otherwise
 */
int test_dotp_sqr(int16_t* src0, int16_t* src1, int size, int32_t G, int print, int32_t input_bytes, int32_t ti_tics) {
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t dsp_tics[2], dsp_instrs[2];

  int32_t ref_out = G;
  int32_t dsp_out = G;

  // count_tics is called directly before and after each function so that
  // element 0 / element 1 of the arrays bracket exactly one call.
  count_tics(&ref_tics[0], &ref_instrs[0]);
  int32_t ref_ret = ref_dotp_sqr(G, src0, src1, &ref_out, size);
  count_tics(&ref_tics[1], &ref_instrs[1]);

  count_tics(&dsp_tics[0], &dsp_instrs[0]);
  int32_t dsp_ret = dotp_sqr(G, src0, src1, &dsp_out, size);
  count_tics(&dsp_tics[1], &dsp_instrs[1]);

  if (print != 0) {
    printf("G:");
    printf("%d\n", G);
    printf("vect1:");
    print_vector_s16(src0, size);
    printf("vect2:");
    print_vector_s16(src1, size);
    printf("dsp_res:");
    printf("%d %d", dsp_ret, dsp_out);
    printf("ref_res:");
    printf("%d %d", ref_ret, ref_out);
  }

  // Any difference in either the returned value or the output argument fails the run.
  const int mismatch = (dsp_ret != ref_ret) || (ref_out != dsp_out);

  print_performance(ref_tics, ref_instrs, dsp_tics, dsp_instrs, input_bytes, ti_tics);

  if (mismatch) {
    printf(" failed |\n");
  } else {
    printf(" passed |\n");
  }

  return mismatch;
}

/*! \brief Compares dotp_sqr32 with ref_dotp_sqr32 on 32-bit input and finishes the table row.
 *
 *  Both functions receive the same G, src0, src1 and size. The run passes when
 *  their return values are equal and the values left in their output argument
 *  are equal.
 *
 *  \param src0         first input vector
 *  \param src1         second input vector
 *  \param size         element count handed to both functions
 *  \param G            value passed as the first argument and used to seed the output argument
 *  \param print        non-zero to dump G, both inputs and both result pairs
 *  \param input_bytes  forwarded to print_performance
 *  \param ti_tics      forwarded to print_performance
 *  \return 0 when the results match, 1 otherwise
 */
int test_dotp_sqr32(int32_t* src0, int32_t* src1, int size, int64_t G, int print, int32_t input_bytes,
                    int32_t ti_tics) {
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t dsp_tics[2], dsp_instrs[2];

  int64_t ref_out = G;
  int64_t dsp_out = G;

  // count_tics is called directly before and after each function so that
  // element 0 / element 1 of the arrays bracket exactly one call.
  count_tics(&ref_tics[0], &ref_instrs[0]);
  int64_t ref_ret = ref_dotp_sqr32(G, src0, src1, &ref_out, size);
  count_tics(&ref_tics[1], &ref_instrs[1]);

  count_tics(&dsp_tics[0], &dsp_instrs[0]);
  int64_t dsp_ret = dotp_sqr32(G, src0, src1, &dsp_out, size);
  count_tics(&dsp_tics[1], &dsp_instrs[1]);

  if (print != 0) {
    // 64-bit values are cast to long long so that %lld matches on every target.
    printf("G:");
    printf("%lld\n", (long long int)G);
    printf("vect1:");
    print_vector_s32(src0, size);
    printf("vect2:");
    print_vector_s32(src1, size);
    printf("dsp_res:");
    printf("%lld %lld", (long long int)dsp_ret, (long long int)dsp_out);
    printf("ref_res:");
    printf("%lld %lld", (long long int)ref_ret, (long long int)ref_out);
  }

  // Any difference in either the returned value or the output argument fails the run.
  const int mismatch = (dsp_ret != ref_ret) || (ref_out != dsp_out);

  print_performance(ref_tics, ref_instrs, dsp_tics, dsp_instrs, input_bytes, ti_tics);

  if (mismatch) {
    printf(" failed |\n");
  } else {
    printf(" passed |\n");
  }

  return mismatch;
}

/*! \brief Compares dotp_sqr_fl with ref_dotp_sqr_fl on float input and finishes the table row.
 *
 *  For each function a two-element result array is built: element 0 is seeded
 *  with G and passed as the output argument, element 1 receives the return
 *  value. The two arrays are then checked with compare_float_eps using EPS
 *  when that macro is defined and 0.00001 otherwise.
 *
 *  \param src0         first input vector
 *  \param src1         second input vector
 *  \param size         element count handed to both functions
 *  \param G            value passed as the first argument and used to seed element 0
 *  \param print        non-zero to dump G, both inputs and both result pairs
 *  \param input_bytes  forwarded to print_performance
 *  \param ti_tics      forwarded to print_performance
 *  \return value returned by compare_float_eps; 0 is reported as passed
 */
int test_dotp_sqr_fl(float* src0, float* src1, int size, float G, int print, int32_t input_bytes, int32_t ti_tics) {
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t dsp_tics[2], dsp_instrs[2];

  float ref_vals[2];
  float dsp_vals[2];

  ref_vals[0] = G;
  dsp_vals[0] = G;

  // count_tics is called directly before and after each function so that
  // element 0 / element 1 of the arrays bracket exactly one call.
  count_tics(&ref_tics[0], &ref_instrs[0]);
  ref_vals[1] = ref_dotp_sqr_fl(G, src0, src1, ref_vals, size);
  count_tics(&ref_tics[1], &ref_instrs[1]);

  count_tics(&dsp_tics[0], &dsp_instrs[0]);
  dsp_vals[1] = dotp_sqr_fl(G, src0, src1, dsp_vals, size);
  count_tics(&dsp_tics[1], &dsp_instrs[1]);

  if (print != 0) {
    printf("G:");
    printf("%f\n", G);
    printf("vect1:");
    print_vector_float(src0, size);
    printf("vect2:");
    print_vector_float(src1, size);
    printf("dsp_res:");
    printf("%f %f", dsp_vals[1], dsp_vals[0]);
    printf("ref_res:");
    printf("%f %f", ref_vals[1], ref_vals[0]);
  }

  // The build may override the default tolerance by defining EPS.
#ifdef EPS
  const int mismatches = compare_float_eps(ref_vals, dsp_vals, 2, EPS);
#else
  const int mismatches = compare_float_eps(ref_vals, dsp_vals, 2, 0.00001);
#endif

  print_performance(ref_tics, ref_instrs, dsp_tics, dsp_instrs, input_bytes, ti_tics);

  if (mismatches != 0) {
    printf(" failed |\n");
  } else {
    printf(" passed |\n");
  }

  return mismatches;
}

/*! \brief Compares dotp_sqr_db with ref_dotp_sqr_db on double input and finishes the table row.
 *
 *  For each function a two-element result array is built: element 0 is seeded
 *  with G and passed as the output argument, element 1 receives the return
 *  value. The two arrays are then checked with compare_double_eps using a
 *  tolerance of 0.000001.
 *
 *  \param src0         first input vector
 *  \param src1         second input vector
 *  \param size         element count handed to both functions
 *  \param G            value passed as the first argument and used to seed element 0
 *  \param print        non-zero to dump G, both inputs and both result pairs
 *  \param input_bytes  forwarded to print_performance
 *  \param ti_tics      forwarded to print_performance
 *  \return value returned by compare_double_eps; 0 is reported as passed
 */
int test_dotp_sqr_db(double* src0, double* src1, int size, double G, int print, int32_t input_bytes, int32_t ti_tics) {
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t dsp_tics[2], dsp_instrs[2];

  double ref_vals[2];
  double dsp_vals[2];

  ref_vals[0] = G;
  dsp_vals[0] = G;

  // count_tics is called directly before and after each function so that
  // element 0 / element 1 of the arrays bracket exactly one call.
  count_tics(&ref_tics[0], &ref_instrs[0]);
  ref_vals[1] = ref_dotp_sqr_db(G, src0, src1, ref_vals, size);
  count_tics(&ref_tics[1], &ref_instrs[1]);

  count_tics(&dsp_tics[0], &dsp_instrs[0]);
  dsp_vals[1] = dotp_sqr_db(G, src0, src1, dsp_vals, size);
  count_tics(&dsp_tics[1], &dsp_instrs[1]);

  if (print != 0) {
    printf("G:");
    printf("%f\n", G);
    printf("vect1:");
    print_vector_double(src0, size);
    printf("vect2:");
    print_vector_double(src1, size);
    printf("dsp_res:");
    printf("%f %f", dsp_vals[1], dsp_vals[0]);
    printf("ref_res:");
    printf("%f %f", ref_vals[1], ref_vals[0]);
  }

  const int mismatches = compare_double_eps(ref_vals, dsp_vals, 2, 0.000001);

  print_performance(ref_tics, ref_instrs, dsp_tics, dsp_instrs, input_bytes, ti_tics);

  if (mismatches != 0) {
    printf(" failed |\n");
  } else {
    printf(" passed |\n");
  }

  return mismatches;
}
