// Copyright 2025 RnD Center "ELVEES", JSC

#include <elcore50-matrix-lib/mat_mul_with_dma_fl16.hpp>

#include "convolve_tests_helper.h"

// fp32 matrix multiplication used by this test as the reference (baseline)
// implementation that the fp16 run_matmul_fl16() result is checked against.
// Defined outside this file with C linkage.
// NOTE(review): the parameter roles below are inferred from the single call
// site in this file only — confirm against the implementation:
//   in1 / in2 / out       : first input, second input and output matrices;
//   row, col0row1, col1   : matrix dimensions (same order as the size table);
//   x_offset, y_offset,
//   real_row, real_col    : output placement; called with 0, 0, 0 and the
//                           output row stride;
//   tics / instr          : caller passes 6-element int arrays (presumably
//                           per-stage tick / instruction counters);
//   offsetA / offsetB     : presumably per-row padding of the inputs.
extern "C" void mm_v0_vliw_1_sub_matrix_pre_load_real_out_offset(float* in1, int row, int col0row1, float* in2,
                                                                 int col1, float* out, int x_offset, int y_offset,
                                                                 int real_row, int real_col, int* tics, int* instr,
                                                                 int offsetA, int offsetB);

int main() {
  disable_l2_cache();

#ifdef DEBUG
  printf("*** DEBUG_OPTION ***\n");
#endif

  int result = 0;
  int test_result = 0;

  int array_row[] = {3136, 3136, 49, 49, 128, 256, 512, 512, 512, 8192};
  int array_row1col0[] = {64, 64, 512, 2048, 128, 128, 128, 1024, 512, 512};
  int array_col1[] = {64, 256, 2048, 512, 64, 32, 128, 256, 512, 512};

  Store_version st_ver = STORE_NONE;

  printf(
      "| func_name         | size (r0, c1r1, c1) | ref mul/tic | opt mul/tic | "
      "status |\n");
  printf(
      "------------------------------------------------------------------------"
      "-------\n");
  char store_version[] = "relu6";

  for (int v = 0; v < 3; ++v) {
    if (v == 0) {
      strcpy(store_version, "none ");
      st_ver = STORE_NONE;
    }
    if (v == 1) {
      strcpy(store_version, "relu ");
      st_ver = STORE_RELU;
    }
    if (v == 2) {
      strcpy(store_version, "relu6");
      st_ver = STORE_RELU6;
    }

    for (int i = 0; i < 10; ++i) {
      printf("| mat_mul_f16 %s |", store_version);

      int row = array_row[i];
      int row1col0 = array_row1col0[i];
      int col1 = array_col1[i];
      printf("   %5d %5d %5d |", row, row1col0, col1);

      int offsetA = 0;
      int offsetC = 0;

      int tic[2], instr[2], block_tic[6], block_instr[6], func_tic, func_instr;

      float* src0_fl32 = (float*)memalign(64, row * (row1col0 + offsetA) * sizeof(float));
      uint16_t* src0_fl16 = (uint16_t*)memalign(64, row * (row1col0 + offsetA) * sizeof(uint16_t));

      float* src1_fl32 = (float*)memalign(64, row1col0 * col1 * sizeof(float));

      float* init_vector = (float*)memalign(64, col1 * sizeof(float));
      uint16_t* src1_fl16 = (uint16_t*)memalign(64, row1col0 * col1 * sizeof(uint16_t));

      float* dst_fl32_ref = (float*)memalign(64, row * (col1 + offsetC) * COEF * sizeof(float));
      uint16_t* dst_fl16 = (uint16_t*)memalign(64, row * (col1 + offsetC) * COEF * sizeof(uint16_t));
      float* dst_fl32_opt = (float*)memalign(64, row * (col1 + offsetC) * COEF * sizeof(float));

#ifndef NO_CALC
      for (int i = 0; i < row * (row1col0 + offsetA); ++i) {
        src0_fl32[i] = (rand() % 5 - 2) * 0.5;
      }

      for (int i = 0; i < row1col0 * col1; ++i) {
        src1_fl32[i] = (rand() % 5 - 2) * 0.5;
      }

      for (int i = 0; i < col1; ++i) {
        init_vector[i] = (rand() % 5 - 2) * 0.5;
      }
#endif

      float32_to_float16(src0_fl32, src0_fl16, row * (row1col0 + offsetA));
      float32_to_float16(src1_fl32, src1_fl16, row1col0 * col1);

      memset(dst_fl32_ref, 0, row * (col1 + offsetC) * COEF * sizeof(float));
      memset(dst_fl32_opt, 0, row * (col1 + offsetC) * COEF * sizeof(float));

      /* запуск референсной версии */
      flush_all_caches();
      count_tics(tic, instr);
#ifndef NO_CALC
      mm_v0_vliw_1_sub_matrix_pre_load_real_out_offset(src0_fl32, row, row1col0, src1_fl32, col1, dst_fl32_ref, 0, 0,
                                                       0, col1 + offsetC, block_tic, block_instr, offsetA, 0);
#endif
      count_tics(&tic[1], &instr[1]);

      for (int i = 0; i < row; ++i) {
        for (int j = 0; j < col1; ++j) {
          dst_fl32_ref[i * (col1 + offsetC) + j] += init_vector[j];
        }
      }

      if (st_ver != STORE_NONE) {
        for (int i = 0; i < row; ++i) {
          for (int j = 0; j < col1; ++j) {
            if (dst_fl32_ref[i * (col1 + offsetC) + j] < 0.0) dst_fl32_ref[i * (col1 + offsetC) + j] = 0.0;

            if (st_ver == STORE_RELU6) {
              if (dst_fl32_ref[i * (col1 + offsetC) + j] > 6.0) dst_fl32_ref[i * (col1 + offsetC) + j] = 6.0;
            }
          }
        }
      }

      func_tic = tic[1] - tic[0];
      func_instr = instr[1] - instr[0];

      printf("     %7.3f |", (float)row * row1col0 * col1 / func_tic);

      memset(dst_fl16, 0, row * (col1 + offsetC) * COEF * sizeof(uint16_t));

      flush_all_caches();
      int offset_A = 0;
      int offset_B = 0;

      MatMulFl16Config config;
      init_dma_chain_matmul_fl16(src0_fl16, row, row1col0, src1_fl16, col1, dst_fl16, offset_A, offset_B, &config,
                                 init_vector, offsetA, offsetC, (uint16_t*)&__local_mem);

      /* запуск опт. версии */
      flush_all_caches();

      count_tics(tic, instr);
      run_matmul_fl16(src0_fl16, row, row1col0, src1_fl16, col1, dst_fl16, offset_A, offset_B, &config, st_ver,
                      init_vector);
      count_tics(&tic[1], &instr[1]);

      destroy_dma_chain_mat_mul_fl16(&config);

      func_tic = tic[1] - tic[0];
      func_instr = instr[1] - instr[0];

      printf("     %7.3f |", (float)row * row1col0 * col1 / func_tic);

      float16_to_float32(dst_fl16, dst_fl32_opt, row * (col1 + offsetC) * COEF);

      result += memcmp(dst_fl32_opt, dst_fl32_ref, row * (col1 + offsetC) * sizeof(float) * COEF);

      if (!result)
        printf(" passed |\n");
      else
        printf(" failed |\n");

      free(src0_fl16);
      free(src1_fl16);
      free(src0_fl32);
      free(src1_fl32);
      free(dst_fl16);
      free(dst_fl32_opt);
      free(dst_fl32_ref);
      free(init_vector);
      test_result += result;
      result = 0;
    }
  }

  printf(
      "------------------------------------------------------------------------"
      "-------\n");

  enable_l2_cache(L2_CACHE_SIZE);
  return test_result;
}
