// Copyright 2018-2025 RnD Center "ELVEES", JSC

/*! \file
 *  \brief Тестирование функции mat_mul_cplx
 *  \author Фролов Андрей
 */

#include "tests.h"

/*!
 *  \brief Entry point of the mat_mul_cplx test program.
 *
 *  Runs the int16, int32, float and double test functions over TEST_COUNT
 *  square matrix sizes (4 ... 64). For every size the data sets labelled
 *  "neg", "rnd" and "pos" are tested; the "neg" and "pos" passes can be
 *  compiled out with DISABLE_NEG_TEST_DATA / DISABLE_POS_TEST_DATA.
 *
 *  Buffers come from the heap by default, or are carved out of a fixed memory
 *  region when LOCAL_MEM is defined.
 *
 *  \return Sum of the values returned by the individual test calls, or 1 when
 *          the heap buffers could not be allocated.
 */
int main() {
  int failed_count = 0;
  int rows[TEST_COUNT] = {4, 8, 16, 32, 64};
  int columns[TEST_COUNT] = {4, 8, 16, 32, 64};
  int columns1[TEST_COUNT] = {4, 8, 16, 32, 64};
  int32_t shift[TEST_COUNT] = {0, 1, 5, 10, 15};
  int print = 0;

  // Element counts of the largest test case (last table entry); the factor 2
  // matches the 2 * columns layout used when the matrices are printed.
  const int last = TEST_COUNT - 1;
  const int src0_len = 2 * rows[last] * columns[last];
  const int src1_len = 2 * columns[last] * columns1[last];
  const int dst_len = 2 * rows[last] * columns1[last];

  // Buffers are sized in int64_t units because the same memory is reused for
  // every element type tested below (int16_t, int32_t, float, double).
  const size_t src0_bytes = (size_t)src0_len * sizeof(int64_t);
  const size_t src1_bytes = (size_t)src1_len * sizeof(int64_t);
  const size_t dst_bytes = (size_t)dst_len * sizeof(int64_t);

  print_table_header();

#ifndef LOCAL_MEM
  void* src0_pos = memalign(64, src0_bytes);
  void* src1_pos = memalign(64, src1_bytes);
  void* src0_neg = memalign(64, src0_bytes);
  void* src1_neg = memalign(64, src1_bytes);
  void* src0_rnd = memalign(64, src0_bytes);
  void* src1_rnd = memalign(64, src1_bytes);
  void* dst_ref = memalign(64, dst_bytes);
  void* dst_opt = memalign(64, dst_bytes);

  // Do not hand NULL buffers to the generators and kernels: report and bail out.
  if (src0_pos == NULL || src1_pos == NULL || src0_neg == NULL || src1_neg == NULL || src0_rnd == NULL ||
      src1_rnd == NULL || dst_ref == NULL || dst_opt == NULL) {
    printf("memalign failed: cannot allocate test buffers\n");
    // free(NULL) is a no-op, so every pointer can be released unconditionally.
    free(src0_pos);
    free(src1_pos);
    free(src0_neg);
    free(src1_neg);
    free(src0_rnd);
    free(src1_rnd);
    free(dst_ref);
    free(dst_opt);
    return 1;
  }
#else
#ifdef BARE_METAL
  void* src0_pos = &__local_mem;
#else
  disable_l2_cache();
  void* src0_pos = &xyram_data;
#endif
  // Lay the remaining buffers out back to back after src0_pos. Offsets are
  // applied through char* because arithmetic on void* is not standard C.
  void* src1_pos = (char*)src0_pos + src0_bytes;
  void* src0_neg = (char*)src1_pos + src1_bytes;
  void* src1_neg = (char*)src0_neg + src0_bytes;
  void* src0_rnd = (char*)src1_neg + src1_bytes;
  void* src1_rnd = (char*)src0_rnd + src0_bytes;
  void* dst_ref = (char*)src1_rnd + src1_bytes;
  void* dst_opt = (char*)dst_ref + dst_bytes;
#endif

  // ---- int16_t variant ----
  // NOTE(review): the last argument (-1 / 1 / 0) presumably selects
  // negative / positive / random data, judging by the buffer names - confirm
  // against create_vector_*.
  create_vector_s16((int16_t*)src0_neg, src0_len, -1);
  create_vector_s16((int16_t*)src1_neg, src1_len, -1);
  create_vector_s16((int16_t*)src0_pos, src0_len, 1);
  create_vector_s16((int16_t*)src1_pos, src1_len, 1);
  create_vector_s16((int16_t*)src0_rnd, src0_len, 0);
  create_vector_s16((int16_t*)src1_rnd, src1_len, 0);

  for (int i = 0; i < TEST_COUNT; ++i) {
    // Dimensions rounded up to multiples of 2, 2 and 4 for the cycle estimate.
    int _r1 = 2 * ceil_(rows[i] / ((double)2.0));
    int _c1 = 2 * ceil_(columns[i] / ((double)2.0));
    int _c2 = 4 * ceil_(columns1[i] / ((double)4.0));

    // Size in bytes of both input matrices (two values per element).
    int32_t input_bytes = columns[i] * (rows[i] + columns1[i]) * sizeof(int16_t) * 2;
    // Reference cycle estimate handed to print_performance
    // (presumably a TI benchmark formula - TODO confirm).
    int32_t ti_tics = (int)(0.25 * (_r1 * _c2 * _c1) + 6.75 * (_r1 * _c2) + 6 * _c2 + 36);

#ifndef DISABLE_NEG_TEST_DATA
    printf("| mat_mul_cplx_short_neg | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx((int16_t*)src0_neg, (int16_t*)src1_neg, (int16_t*)dst_ref, (int16_t*)dst_opt,
                                      rows[i], columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);
#endif

    printf("| mat_mul_cplx_short_rnd | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx((int16_t*)src0_rnd, (int16_t*)src1_rnd, (int16_t*)dst_ref, (int16_t*)dst_opt,
                                      rows[i], columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);

#ifndef DISABLE_POS_TEST_DATA
    printf("| mat_mul_cplx_short_pos | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx((int16_t*)src0_pos, (int16_t*)src1_pos, (int16_t*)dst_ref, (int16_t*)dst_opt,
                                      rows[i], columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);
#endif
  }

  // ---- int32_t variant (input buffers are refilled in place) ----
  create_vector_s32((int32_t*)src0_neg, src0_len, -1);
  create_vector_s32((int32_t*)src1_neg, src1_len, -1);
  create_vector_s32((int32_t*)src0_pos, src0_len, 1);
  create_vector_s32((int32_t*)src1_pos, src1_len, 1);
  create_vector_s32((int32_t*)src0_rnd, src0_len, 0);
  create_vector_s32((int32_t*)src1_rnd, src1_len, 0);
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = columns[i] * (rows[i] + columns1[i]) * sizeof(int32_t) * 2;
    // No reference cycle figure is supplied for this variant.
    int32_t ti_tics = 0;

#ifndef DISABLE_NEG_TEST_DATA
    printf("| mat_mul_cplx_int_neg   | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx32((int32_t*)src0_neg, (int32_t*)src1_neg, (int32_t*)dst_ref, (int32_t*)dst_opt,
                                        rows[i], columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);
#endif

    printf("| mat_mul_cplx_int_rnd   | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx32((int32_t*)src0_rnd, (int32_t*)src1_rnd, (int32_t*)dst_ref, (int32_t*)dst_opt,
                                        rows[i], columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);

#ifndef DISABLE_POS_TEST_DATA
    printf("| mat_mul_cplx_int_pos   | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx32((int32_t*)src0_pos, (int32_t*)src1_pos, (int32_t*)dst_ref, (int32_t*)dst_opt,
                                        rows[i], columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);
#endif
  }

  // ---- float variant ----
  create_vector_float((float*)src0_neg, src0_len, -1);
  create_vector_float((float*)src1_neg, src1_len, -1);
  create_vector_float((float*)src0_pos, src0_len, 1);
  create_vector_float((float*)src1_pos, src1_len, 1);
  create_vector_float((float*)src0_rnd, src0_len, 0);
  create_vector_float((float*)src1_rnd, src1_len, 0);
  for (int i = 0; i < TEST_COUNT; ++i) {
    int r1 = rows[i];
    int c1 = columns[i];
    int c2 = columns1[i];

    int32_t input_bytes = columns[i] * (rows[i] + columns1[i]) * sizeof(float) * 2;
    // Reference cycle estimate handed to print_performance
    // (presumably a TI benchmark formula - TODO confirm).
    int32_t ti_tics = (int)(5.0 / 8 * r1 * c2 * c1 + 58.0 / 8 * r1 * c2 + 100.0 / 8 * r1 + 30);

#ifndef DISABLE_NEG_TEST_DATA
    printf("| mat_mul_cplx_fl_neg    | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx_fl((float*)src0_neg, (float*)src1_neg, (float*)dst_ref, (float*)dst_opt, rows[i],
                                         columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);
#endif

    printf("| mat_mul_cplx_fl_rnd    | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx_fl((float*)src0_rnd, (float*)src1_rnd, (float*)dst_ref, (float*)dst_opt, rows[i],
                                         columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);

#ifndef DISABLE_POS_TEST_DATA
    printf("| mat_mul_cplx_fl_pos    | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx_fl((float*)src0_pos, (float*)src1_pos, (float*)dst_ref, (float*)dst_opt, rows[i],
                                         columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);
#endif
  }

  // ---- double variant ----
  create_vector_double((double*)src0_neg, src0_len, -1);
  create_vector_double((double*)src1_neg, src1_len, -1);
  create_vector_double((double*)src0_pos, src0_len, 1);
  create_vector_double((double*)src1_pos, src1_len, 1);
  create_vector_double((double*)src0_rnd, src0_len, 0);
  create_vector_double((double*)src1_rnd, src1_len, 0);
  for (int i = 0; i < TEST_COUNT; ++i) {
    int32_t input_bytes = columns[i] * (rows[i] + columns1[i]) * sizeof(double) * 2;
    // No reference cycle figure is supplied for this variant.
    int32_t ti_tics = 0;

#ifndef DISABLE_NEG_TEST_DATA
    printf("| mat_mul_cplx_db_neg    | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx_db((double*)src0_neg, (double*)src1_neg, (double*)dst_ref, (double*)dst_opt,
                                         rows[i], columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);
#endif

    printf("| mat_mul_cplx_db_rnd    | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx_db((double*)src0_rnd, (double*)src1_rnd, (double*)dst_ref, (double*)dst_opt,
                                         rows[i], columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);

#ifndef DISABLE_POS_TEST_DATA
    printf("| mat_mul_cplx_db_pos    | %4d %4d %4d |", rows[i], columns[i], columns1[i]);
    failed_count += test_mat_mul_cplx_db((double*)src0_pos, (double*)src1_pos, (double*)dst_ref, (double*)dst_opt,
                                         rows[i], columns[i], columns1[i], print, shift[i], input_bytes, ti_tics);
#endif
  }

#ifndef LOCAL_MEM
  free(src0_neg);
  free(src1_neg);
  free(src0_pos);
  free(src1_pos);
  free(src0_rnd);
  free(src1_rnd);

  free(dst_ref);
  free(dst_opt);
#else
#ifndef BARE_METAL
  // Restore the L2 cache that was disabled before the buffers were set up.
  enable_l2_cache(L2_CACHE_SIZE);
#endif
#endif

  return failed_count;
}

/*!
 *  \brief Runs ref_mat_mul_cplx and mat_mul_cplx on the same int16_t inputs,
 *         compares the two outputs and prints the performance and verdict
 *         columns of one table row.
 *
 *  \param src0         First input matrix (rows0 x columns0)
 *  \param src1         Second input matrix (columns0 x columns1)
 *  \param dst_ref      Output buffer written by ref_mat_mul_cplx
 *  \param dst_opt      Output buffer written by mat_mul_cplx
 *  \param rows0        Number of rows of the first matrix
 *  \param columns0     Number of columns of the first matrix
 *  \param columns1     Number of columns of the second matrix
 *  \param print        Non-zero to dump the inputs and both results
 *  \param shift        Shift value forwarded unchanged to both implementations
 *  \param input_bytes  Input size forwarded to print_performance
 *  \param ti_tics      Reference cycle figure forwarded to print_performance
 *  \return Value returned by compare_s16; 0 is reported as "passed"
 */
int test_mat_mul_cplx(int16_t* src0, int16_t* src1, int16_t* dst_ref, int16_t* dst_opt, int rows0, int columns0,
                      int columns1, int print, int32_t shift, int32_t input_bytes, int32_t ti_tics) {
  // Index 0 holds the counters sampled before the call, index 1 those sampled after it.
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t opt_tics[2], opt_instrs[2];

  // Each matrix row is printed with two values per element.
  const int src0_width = 2 * columns0;
  const int src1_width = 2 * columns1;
  const int result_len = rows0 * columns1 * 2;

  // Reference implementation.
  count_tics(&ref_tics[0], &ref_instrs[0]);
  ref_mat_mul_cplx(src0, rows0, columns0, src1, columns1, dst_ref, shift);
  count_tics(&ref_tics[1], &ref_instrs[1]);

  // Implementation under test.
  count_tics(&opt_tics[0], &opt_instrs[0]);
  mat_mul_cplx(src0, rows0, columns0, src1, columns1, dst_opt, shift);
  count_tics(&opt_tics[1], &opt_instrs[1]);

  if (print) {
    printf("mat1 %d X %d\n:", rows0, columns0);
    print_matrix_s16(src0, rows0, src0_width);
    printf("mat2 %d X %d\n:", columns0, columns1);
    print_matrix_s16(src1, columns0, src1_width);

    printf("ref_res %d X %d\n:", rows0, columns1);
    print_matrix_s16(dst_ref, rows0, src1_width);
    printf("dsp_res %d X %d\n:", rows0, columns1);
    print_matrix_s16(dst_opt, rows0, src1_width);
  }

  const int status = compare_s16(dst_ref, dst_opt, result_len);

  print_performance(ref_tics, ref_instrs, opt_tics, opt_instrs, input_bytes, ti_tics);

  if (status != 0) {
    printf(" failed |\n");
  } else {
    printf(" passed |\n");
  }

  return status;
}

/*!
 *  \brief Runs ref_mat_mul_cplx32 and mat_mul_cplx32 on the same int32_t
 *         inputs, compares the two outputs and prints the performance and
 *         verdict columns of one table row.
 *
 *  \param src0         First input matrix (rows0 x columns0)
 *  \param src1         Second input matrix (columns0 x columns1)
 *  \param dst_ref      Output buffer written by ref_mat_mul_cplx32
 *  \param dst_opt      Output buffer written by mat_mul_cplx32
 *  \param rows0        Number of rows of the first matrix
 *  \param columns0     Number of columns of the first matrix
 *  \param columns1     Number of columns of the second matrix
 *  \param print        Non-zero to dump the inputs and both results
 *  \param shift        Shift value forwarded unchanged to both implementations
 *  \param input_bytes  Input size forwarded to print_performance
 *  \param ti_tics      Reference cycle figure forwarded to print_performance
 *  \return Value returned by compare_s32; 0 is reported as "passed"
 */
int test_mat_mul_cplx32(int32_t* src0, int32_t* src1, int32_t* dst_ref, int32_t* dst_opt, int rows0, int columns0,
                        int columns1, int print, int32_t shift, int32_t input_bytes, int32_t ti_tics) {
  // Index 0 holds the counters sampled before the call, index 1 those sampled after it.
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t opt_tics[2], opt_instrs[2];

  // Each matrix row is printed with two values per element.
  const int src0_width = 2 * columns0;
  const int src1_width = 2 * columns1;
  const int result_len = rows0 * columns1 * 2;

  // Reference implementation.
  count_tics(&ref_tics[0], &ref_instrs[0]);
  ref_mat_mul_cplx32(src0, rows0, columns0, src1, columns1, dst_ref, shift);
  count_tics(&ref_tics[1], &ref_instrs[1]);

  // Implementation under test.
  count_tics(&opt_tics[0], &opt_instrs[0]);
  mat_mul_cplx32(src0, rows0, columns0, src1, columns1, dst_opt, shift);
  count_tics(&opt_tics[1], &opt_instrs[1]);

  if (print) {
    printf("mat1 %d X %d\n:", rows0, columns0);
    print_matrix_s32(src0, rows0, src0_width);
    printf("mat2 %d X %d\n:", columns0, columns1);
    print_matrix_s32(src1, columns0, src1_width);

    printf("ref_res %d X %d\n:", rows0, columns1);
    print_matrix_s32(dst_ref, rows0, src1_width);
    printf("dsp_res %d X %d\n:", rows0, columns1);
    print_matrix_s32(dst_opt, rows0, src1_width);
  }

  const int status = compare_s32(dst_ref, dst_opt, result_len);

  print_performance(ref_tics, ref_instrs, opt_tics, opt_instrs, input_bytes, ti_tics);

  if (status != 0) {
    printf(" failed |\n");
  } else {
    printf(" passed |\n");
  }

  return status;
}

/*!
 *  \brief Runs ref_mat_mul_cplx_fl and mat_mul_cplx_fl on the same float
 *         inputs, compares the two outputs and prints the performance and
 *         verdict columns of one table row.
 *
 *  The comparison uses compare_float_eps with EPS when EPS is defined at
 *  compile time, and compare_float otherwise.
 *
 *  \param src0         First input matrix (rows0 x columns0)
 *  \param src1         Second input matrix (columns0 x columns1)
 *  \param dst_ref      Output buffer written by ref_mat_mul_cplx_fl
 *  \param dst_opt      Output buffer written by mat_mul_cplx_fl
 *  \param rows0        Number of rows of the first matrix
 *  \param columns0     Number of columns of the first matrix
 *  \param columns1     Number of columns of the second matrix
 *  \param print        Non-zero to dump the inputs and both results
 *  \param shift        Not used by this variant
 *  \param input_bytes  Input size forwarded to print_performance
 *  \param ti_tics      Reference cycle figure forwarded to print_performance
 *  \return Value returned by the comparison function; 0 is reported as "passed"
 */
int test_mat_mul_cplx_fl(float* src0, float* src1, float* dst_ref, float* dst_opt, int rows0, int columns0,
                         int columns1, int print, int32_t shift, int32_t input_bytes, int32_t ti_tics) {
  // Index 0 holds the counters sampled before the call, index 1 those sampled after it.
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t opt_tics[2], opt_instrs[2];

  // Each matrix row is printed with two values per element.
  const int src0_width = 2 * columns0;
  const int src1_width = 2 * columns1;
  const int result_len = rows0 * columns1 * 2;

  // Reference implementation.
  count_tics(&ref_tics[0], &ref_instrs[0]);
  ref_mat_mul_cplx_fl(src0, rows0, columns0, src1, columns1, dst_ref);
  count_tics(&ref_tics[1], &ref_instrs[1]);

  // Implementation under test.
  count_tics(&opt_tics[0], &opt_instrs[0]);
  mat_mul_cplx_fl(src0, rows0, columns0, src1, columns1, dst_opt);
  count_tics(&opt_tics[1], &opt_instrs[1]);

  if (print) {
    printf("mat1 %d X %d\n:", rows0, columns0);
    print_matrix_float(src0, rows0, src0_width);
    printf("mat2 %d X %d\n:", columns0, columns1);
    print_matrix_float(src1, columns0, src1_width);

    printf("ref_res %d X %d\n:", rows0, columns1);
    print_matrix_float(dst_ref, rows0, src1_width);
    printf("dsp_res %d X %d\n:", rows0, columns1);
    print_matrix_float(dst_opt, rows0, src1_width);
  }

#ifdef EPS
  const int status = compare_float_eps(dst_ref, dst_opt, result_len, EPS);
#else
  const int status = compare_float(dst_ref, dst_opt, result_len);
#endif

  print_performance(ref_tics, ref_instrs, opt_tics, opt_instrs, input_bytes, ti_tics);

  if (status != 0) {
    printf(" failed |\n");
  } else {
    printf(" passed |\n");
  }

  return status;
}

/*!
 *  \brief Runs ref_mat_mul_cplx_db and mat_mul_cplx_db on the same double
 *         inputs, compares the two outputs and prints the performance and
 *         verdict columns of one table row.
 *
 *  The comparison uses compare_double_eps with EPS when EPS is defined at
 *  compile time, and compare_double otherwise.
 *
 *  \param src0         First input matrix (rows0 x columns0)
 *  \param src1         Second input matrix (columns0 x columns1)
 *  \param dst_ref      Output buffer written by ref_mat_mul_cplx_db
 *  \param dst_opt      Output buffer written by mat_mul_cplx_db
 *  \param rows0        Number of rows of the first matrix
 *  \param columns0     Number of columns of the first matrix
 *  \param columns1     Number of columns of the second matrix
 *  \param print        Non-zero to dump the inputs and both results
 *  \param shift        Not used by this variant
 *  \param input_bytes  Input size forwarded to print_performance
 *  \param ti_tics      Reference cycle figure forwarded to print_performance
 *  \return Value returned by the comparison function; 0 is reported as "passed"
 */
int test_mat_mul_cplx_db(double* src0, double* src1, double* dst_ref, double* dst_opt, int rows0, int columns0,
                         int columns1, int print, int32_t shift, int32_t input_bytes, int32_t ti_tics) {
  // Index 0 holds the counters sampled before the call, index 1 those sampled after it.
  uint32_t ref_tics[2], ref_instrs[2];
  uint32_t opt_tics[2], opt_instrs[2];

  // Each matrix row is printed with two values per element.
  const int src0_width = 2 * columns0;
  const int src1_width = 2 * columns1;
  const int result_len = rows0 * columns1 * 2;

  // Reference implementation.
  count_tics(&ref_tics[0], &ref_instrs[0]);
  ref_mat_mul_cplx_db(src0, rows0, columns0, src1, columns1, dst_ref);
  count_tics(&ref_tics[1], &ref_instrs[1]);

  // Implementation under test.
  count_tics(&opt_tics[0], &opt_instrs[0]);
  mat_mul_cplx_db(src0, rows0, columns0, src1, columns1, dst_opt);
  count_tics(&opt_tics[1], &opt_instrs[1]);

  if (print) {
    printf("mat1 %d X %d\n:", rows0, columns0);
    print_matrix_double(src0, rows0, src0_width);
    printf("mat2 %d X %d\n:", columns0, columns1);
    print_matrix_double(src1, columns0, src1_width);

    printf("ref_res %d X %d\n:", rows0, columns1);
    print_matrix_double(dst_ref, rows0, src1_width);
    printf("dsp_res %d X %d\n:", rows0, columns1);
    print_matrix_double(dst_opt, rows0, src1_width);
  }

#ifdef EPS
  const int status = compare_double_eps(dst_ref, dst_opt, result_len, EPS);
#else
  const int status = compare_double(dst_ref, dst_opt, result_len);
#endif

  print_performance(ref_tics, ref_instrs, opt_tics, opt_instrs, input_bytes, ti_tics);

  if (status != 0) {
    printf(" failed |\n");
  } else {
    printf(" passed |\n");
  }

  return status;
}
