// Copyright 2025 RnD Center "ELVEES", JSC

#include <malloc.h>
#include <stdio.h>
#include <string.h>

#include <cmath>

#include "convolve_tests_helper.h"

// NOTE(review): FL_16 is not referenced by any visible code in this file, and it is defined
// after convolve_tests_helper.h is included, so it cannot affect declarations in that header
// (only macros expanded later) — confirm whether it is still needed.
#define FL_16

// Tolerance passed to comparator_fl32 when comparing the optimized fp16 result (converted back
// to fp32) with the fp32 reference; whether it is absolute or relative depends on
// comparator_fl32 — TODO confirm.
#define EPS 0.001

// Test-case table. main() indexes every array below with the same test index, so entry i of
// each array describes test case i; all arrays must therefore have the same length.

// Input tensor shape: batch, height, width, channels.
int array_src_b[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
int array_src_h[] = {112, 17, 17, 28, 28, 8, 8, 8, 8, 8};
int array_src_w[] = {112, 17, 17, 28, 28, 8, 8, 8, 8, 8};
int array_src_ch[] = {28, 128, 128, 128, 128, 384, 384, 384, 384, 384};
// Kernel size along X and Y.
int array_kerX[] = {3, 1, 7, 3, 3, 3, 3, 3, 3, 3};
int array_kerY[] = {3, 7, 1, 3, 3, 1, 1, 1, 1, 1};
// Output tensor shape: batch, height, width, channels (channels is also the weights' output dimension).
int array_dst_b[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
int array_dst_h[] = {112, 17, 17, 28, 28, 8, 8, 8, 8, 8};
int array_dst_w[] = {112, 17, 17, 28, 28, 8, 8, 8, 8, 8};
int array_dst_ch[] = {28, 128, 128, 128, 256, 384, 32, 64, 96, 160};
// Convolution strides along X and Y.
int array_stride_X[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
int array_stride_Y[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
// 1: bias is filled with rand() % 3 values and the flag is forwarded to the fp16 kernel;
// 0: bias is all zeros.
int bias_flags[] = {1, 1, 1, 0, 1, 1, 1, 1, 1, 1};
// Padding passed to ref_convolution and init_dma_chain_conv_fl16. Judging by the table
// (kerX == 7 pairs with padX == padW == 3, kerY == 7 with padY == padH == 3), padX/padW pad
// along X and padY/padH along Y; presumably padX/padY are the leading side and padW/padH the
// trailing side — TODO confirm against ref_convolution.
int array_padX[] = {1, 0, 3, 1, 1, 1, 1, 1, 1, 1};
int array_padY[] = {1, 3, 0, 1, 1, 0, 0, 0, 0, 0};
int array_padW[] = {1, 0, 3, 1, 1, 1, 1, 1, 1, 1};
int array_padH[] = {1, 3, 0, 1, 1, 0, 0, 0, 0, 0};
// Store (activation) modes; main() cycles through them by test index.
Store_version st_vers[] = {STORE_NONE, STORE_RELU, STORE_RELU6};

int main() {
  disable_l2_cache();

  int test_result = 0;

  printf(
      "| func_name        | input (b, h, w, ch)    | bias | store mode |"
      " opt mul/tic | status |\n");
  printf(
      "|                  | ker (X, Y), str (X, Y) |      |            |"
      "             |        |\n");
  printf(
      "|                  | output (b, h, w, ch)   |      |            |"
      "             |        |\n");
  printf(
      "-----------------------------------------------------------------"
      "-----------------------\n");

  change_L1DC_ctrl();
  Store_version st_ver = STORE_NONE;

  for (int indx = 0; indx < 10; ++indx) {
    printf("| convolution fl16 | (%1d %4d %4d %4d)     |    %d |", array_src_b[indx], array_src_h[indx],
           array_src_w[indx], array_src_ch[indx], bias_flags[indx]);

    st_ver = st_vers[indx % 3];

    if (st_ver == STORE_NONE)
      printf("      none  | ");
    else if (st_ver == STORE_RELU)
      printf("      relu  | ");
    else
      printf("      relu6 | ");

    int result = 0;
    int padY, padX, padH, padW;

    padY = padX = padH = padW = 0;
    padX = array_padX[indx];
    padY = array_padY[indx];
    padW = array_padW[indx];
    padH = array_padH[indx];

    // шаги ядра
    int strideY = array_stride_Y[indx];
    int strideX = array_stride_X[indx];

    int src_ch_for_ref_ver = array_src_ch[indx];
    if (array_src_ch[indx] % 8) {
      src_ch_for_ref_ver += 8 - (array_src_ch[indx] % 8);
    }

    // data, b, h,  w, c
    Tensor_fl32 input = {NULL, array_src_b[indx], array_src_h[indx], array_src_w[indx], array_src_ch[indx]};
    input.data =
        (float *)memalign(64, sizeof(float) * input.batch * input.height * input.width * input.channel * COEF);
    set_data_tensor(&input);

    // тензор кратного размера
    Tensor_fl32 input_for_ref_ver = {NULL, array_src_b[indx], array_src_h[indx], array_src_w[indx],
                                     src_ch_for_ref_ver};
    input_for_ref_ver.data = (float *)memalign(64, sizeof(float) * input_for_ref_ver.batch * input_for_ref_ver.height *
                                                       input_for_ref_ver.width * input_for_ref_ver.channel * COEF);

    tensor_data_copy(&input, &input_for_ref_ver);

    Tensor_fl16 input_fl16_for_ref = {NULL, input.batch, input.height, input.width, input_for_ref_ver.channel};
    input_fl16_for_ref.data =
        (uint16_t *)memalign(64, input_fl16_for_ref.batch * input_fl16_for_ref.height * input_fl16_for_ref.width *
                                     input_fl16_for_ref.channel * sizeof(uint16_t) * COEF);
    float32_to_float16(
        input_for_ref_ver.data, input_fl16_for_ref.data,
        input_fl16_for_ref.batch * input_fl16_for_ref.height * input_fl16_for_ref.width * input_fl16_for_ref.channel);

    Tensor_fl16 input_fl16 = {NULL, input.batch, input.height, input.width, input.channel};
    input_fl16.data = (uint16_t *)memalign(
        64, input_fl16.batch * input_fl16.height * input_fl16.width * input_fl16.channel * sizeof(uint16_t) * COEF);
    float32_to_float16(input.data, input_fl16.data, input.batch * input.height * input.width * input.channel);

    Tensor_fl32 output_ref = {NULL, array_dst_b[indx], array_dst_h[indx], array_dst_w[indx], array_dst_ch[indx]};
    output_ref.data = (float *)memalign(
        64, sizeof(float) * output_ref.batch * output_ref.height * output_ref.width * output_ref.channel * COEF);
    memset(output_ref.data, 0,
           sizeof(float) * output_ref.batch * output_ref.height * output_ref.width * output_ref.channel * COEF);

    Tensor_fl32 output_opt = {NULL, output_ref.batch, output_ref.height, output_ref.width, output_ref.channel};
    output_opt.data = (float *)memalign(
        64, sizeof(float) * output_opt.batch * output_opt.height * output_opt.width * output_opt.channel * COEF);
    memset(output_opt.data, 0,
           sizeof(float) * output_opt.batch * output_opt.height * output_opt.width * output_opt.channel * COEF);

    Tensor_fl16 output_opt_fl16 = {NULL, output_ref.batch, output_ref.height, output_ref.width, output_ref.channel};
    output_opt_fl16.data = (uint16_t *)calloc(
        output_opt_fl16.batch * output_opt_fl16.height * output_opt_fl16.width * output_opt_fl16.channel,
        sizeof(uint16_t) * COEF);

    // data, X, Y
    Weight_fl32 weight_dyxc = {NULL, array_kerX[indx], array_kerY[indx], input.channel, output_ref.channel};
    Weight_fl32 weight_yxcd = {NULL, weight_dyxc.kernelX, weight_dyxc.kernelY, input.channel, output_ref.channel};
    weight_dyxc.data = (float *)memalign(
        64, sizeof(float) * weight_dyxc.kernelX * weight_dyxc.kernelY * input.channel * output_ref.channel);
    set_data_weight_dyxc(&weight_dyxc);
    weight_yxcd.data = (float *)memalign(
        64, sizeof(float) * weight_yxcd.kernelX * weight_yxcd.kernelY * input.channel * output_ref.channel);
    weight_dyxc_to_yxcd(&weight_dyxc, &weight_yxcd);

    Weight_fl32 weight_yxcd_ref_ver = {NULL, weight_dyxc.kernelX, weight_dyxc.kernelY, input_for_ref_ver.channel,
                                       output_ref.channel};
    weight_yxcd_ref_ver.data = (float *)memalign(64, sizeof(float) * weight_yxcd.kernelX * weight_yxcd.kernelY *
                                                         input_for_ref_ver.channel * output_ref.channel);

    Weight_fl16 weight_yxcd_fl16_ref_ver = {NULL, weight_dyxc.kernelX, weight_dyxc.kernelY, input_for_ref_ver.channel,
                                            output_ref.channel};
    weight_yxcd_fl16_ref_ver.data =
        (uint16_t *)memalign(64, sizeof(uint16_t) * weight_yxcd.kernelX * weight_yxcd.kernelY *
                                     input_for_ref_ver.channel * output_ref.channel);

    weight_data_copy(&weight_yxcd, &weight_yxcd_ref_ver);

    float32_to_float16(weight_yxcd_ref_ver.data, weight_yxcd_fl16_ref_ver.data,
                       weight_yxcd.kernelX * weight_yxcd.kernelY * input_for_ref_ver.channel * output_ref.channel);

    // растяжение свертки
    int dilationY = 1;
    int dilationX = 1;

    Weight_fl16 weight_yxcd_fl16 = {NULL,          weight_dyxc.kernelX, weight_dyxc.kernelY,
                                    input.channel, output_ref.channel,  strideX,
                                    strideY,       dilationX,           dilationY};

    weight_yxcd_fl16.data = (uint16_t *)memalign(
        64, sizeof(uint16_t) * weight_yxcd.kernelX * weight_yxcd.kernelY * input.channel * output_ref.channel);
    float32_to_float16(weight_yxcd.data, weight_yxcd_fl16.data,
                       weight_yxcd.kernelX * weight_yxcd.kernelY * input.channel * output_ref.channel);

    float *bias = (float *)memalign(64, output_ref.channel * sizeof(float));

    for (int i = 0; i < output_ref.channel; ++i) {
      bias[i] = (bias_flags[indx]) ? (rand() % 3) : 0;
    }

    int group = 1;

    ref_convolution(input.data, input.batch, input.height, input.width, input.channel, weight_dyxc.kernelY,
                    weight_dyxc.kernelX, dilationY, dilationX, strideY, strideX, padY, padX, padH, padW, group,
                    weight_dyxc.data, bias, output_ref.data, output_ref.channel);

    memset(output_opt_fl16.data, 0,
           output_opt_fl16.batch * output_opt_fl16.height * output_opt_fl16.width * output_opt_fl16.channel *
               sizeof(uint16_t) * COEF);

    flush_all_caches();

    memset(output_opt_fl16.data, 0,
           output_opt_fl16.batch * output_opt_fl16.height * output_opt_fl16.width * output_opt_fl16.channel *
               sizeof(uint16_t) * COEF);

    if (st_ver != STORE_NONE) {
      for (int i = 0; i < output_ref.height; ++i) {
        for (int j = 0; j < output_ref.width; ++j) {
          for (int k = 0; k < output_ref.channel; ++k) {
            int res_indx = i * output_ref.width * output_ref.channel + j * output_ref.channel + k;
            if (output_ref.data[res_indx] < 0.0) output_ref.data[res_indx] = 0.0;

            if (st_ver == STORE_RELU6) {
              if (output_ref.data[res_indx] > 6.0) output_ref.data[res_indx] = 6.0;
            }
          }
        }
      }
    }

    /* Подготовка данных для запуска опт. версии */
    ConvFl16Config config;
    int tic[2], instr[2], func_tic, func_instr;

    init_dma_chain_conv_fl16(&input_fl16, group, padX, padW, padY, padH, &weight_yxcd_fl16, bias, &output_opt_fl16,
                             &config, bias_flags[indx], (uint16_t *)&__local_mem, 524288);
    count_tics(tic, instr);
    run_conv_fl16(&input_fl16, group, &weight_yxcd_fl16, bias, &output_opt_fl16, &config, st_ver, bias_flags[indx]);
    count_tics(&tic[1], &instr[1]);

    destroy_dma_chain_conv_fl16(&config);

    flush_all_caches();

    func_tic = tic[1] - tic[0];
    func_instr = instr[1] - instr[0];

    printf("   %7.3f  | ", (float)array_dst_h[indx] * array_dst_w[indx] * array_src_ch[indx] * array_kerX[indx] *
                               array_kerY[indx] * array_dst_ch[indx] / func_tic);

    float16_to_float32(
        output_opt_fl16.data, output_opt.data,
        output_opt_fl16.batch * output_opt_fl16.height * output_opt_fl16.width * output_opt_fl16.channel * COEF);

    result += comparator_fl32(output_ref.data, output_opt.data,
                              output_ref.batch * output_ref.height * output_ref.width * output_ref.channel * COEF,
                              (float)EPS);

    if (!result)
      printf("passed |\n");
    else
      printf("failed |\n");

    test_result += result;

    printf(
        "|                  | (%1d %2d) (%1d %2d)          |      |            "
        "|             |        |\n",
        array_kerX[indx], array_kerY[indx], array_stride_X[indx], array_stride_Y[indx]);
    printf(
        "|                  | (%1d %4d %4d %4d)     |      |            |      "
        "       |        |\n",
        array_dst_b[indx], array_dst_h[indx], array_dst_w[indx], array_dst_ch[indx]);
    printf(
        "----------------------------------------------------------------------"
        "------------------\n");

    free(weight_yxcd_fl16.data);
    free(output_opt_fl16.data);
    free(input_fl16.data);

    free(output_opt.data);
    free(input.data);
    free(output_ref.data);
    free(weight_dyxc.data);
    free(weight_yxcd.data);
    free(bias);
  }

  enable_l2_cache(L2_CACHE_SIZE);
  return test_result;
}
