// Copyright 2025 RnD Center "ELVEES", JSC

/*! \file
 *  \brief Тестирование двумерной линейной свертки
 */

#include <stdlib.h>

#include <elcore50-signal-lib/fft.h>

#include "argparser.h"
#include "common.h"

/* Prints the command-line help for test_conv_2d: the generic options printed
 * by argparser_print_usage(), followed by the options specific to this test
 * and the size limits on the convolution result. */
void print_usage(void) {
  printf("Usage: test_conv_2d [OPTIONS]\n");
  printf("Options description in strict order:\n");
  argparser_print_usage();
  printf("\t-m    Data location: {DDR, XYRAM}\n");
  printf("\t-r0   Number of rows of the first matrix\n");
  printf("\t-c0   Number of columns of the first matrix\n");
  printf("\t-r1   Number of rows of the second matrix:\n");
  printf("\t     \tTotal rows number of convolution:\n");
  printf("\t     \t  16 <= rows0 + rows1 - 1 <= 512\n");
  printf("\t-c1   Number of columns of the second matrix:\n");
  printf("\t     \tTotal cols number of convolution:\n");
  printf("\t     \t  16 <= cols0 + cols1 - 1 <= 512 for 2d convolution\n");
  printf("\t-h    Print usage\n");
}

/* Limits on each dimension of the linear-convolution result
 * (rows0 + rows1 - 1 and cols0 + cols1 - 1), as documented in print_usage(). */
enum { CONV_DIM_MIN = 16, CONV_DIM_MAX = 512 };

/* Parses one matrix dimension from a command-line argument.
 * Returns 0 and stores the value in *out only when `str` is a complete
 * decimal integer in [1, CONV_DIM_MAX]; otherwise returns -1 and leaves *out
 * untouched. Unlike atoi(), this rejects empty, non-numeric, trailing-garbage
 * and out-of-range input (strtol clamps overflow to LONG_MAX/LONG_MIN, which
 * the range check then rejects). */
static int parse_dim(const char *str, int32_t *out) {
  char *end = NULL;
  const long value = strtol(str, &end, 10);
  if (end == str || *end != '\0' || value < 1 || value > CONV_DIM_MAX) {
    return -1;
  }
  *out = (int32_t)value;
  return 0;
}

/* Returns the smallest power of two that is >= n, for 1 <= n <= CONV_DIM_MAX.
 * Integer equivalent of 1 << ceil(log2f(n)) without the float round trip. */
static int32_t next_pow2(int32_t n) {
  int32_t pow2 = 1;
  while (pow2 < n) {
    pow2 <<= 1;
  }
  return pow2;
}

/* Allocates `count` floats with memalign(), aligned to the buffer's own byte
 * size, and fills them from `path` with read_data_from_file().
 * Returns NULL on allocation or read failure (the buffer is freed). */
static float *load_aligned(int32_t count, char *path) {
  const size_t bytes = (size_t)count * sizeof(float);
  float *buf = memalign(bytes, bytes);
  if (buf == NULL) {
    printf("Error: memory allocation failed!\n");
    return NULL;
  }
  if (read_data_from_file(buf, count, path)) {
    free(buf);
    return NULL;
  }
  return buf;
}

/* Entry point of the 2-D FFT-convolution test.
 *
 * Parses the command line (see print_usage()), reads the two input buffers
 * and the reference result from the files obtained through argparser(), runs
 * fft_conv_2d() between two ticks_counter() samples, and compares the result
 * against the reference with maximum_norm_conv2d().
 *
 * Returns 0 when the test passes and 1 otherwise: bad arguments, I/O or
 * allocation failure, a non-zero fft_conv_2d() return, or a norm >= 0.1.
 */
int main(int argc, char *argv[]) {
  if ((argc != 2) && (argc != 16)) {
    printf("Error: wrong number of parameters!\n");
    print_usage();
    return 1;
  }

  if (!strcmp(argv[1], "-h")) {
    print_usage();
    return 0;
  }

  char *infiles[MAX_FILES];
  char *outfiles[MAX_FILES];
  const int8_t len = argparser(argc, argv, infiles, outfiles);
  if (len != 3) {
    printf("Error: Wrong amount of input files!\n");
    return 1;
  }

  /* Every option starts as "not given" (-1 / 0) and is cross-validated only
   * after the whole command line is read. No value is therefore used before
   * it has been assigned, whatever the order of the options. */
  int8_t is_data_in_xyram = -1;
  int32_t rows0 = 0;
  int32_t cols0 = 0;
  int32_t rows1 = 0;
  int32_t cols1 = 0;
  /* NOTE(review): assumes argparser() consumes argv[1..5]; test-specific
   * key/value pairs start at argv[6]. */
  for (int i = 6; i + 1 < argc; i += 2) {
    const char *key = argv[i];
    const char *value = argv[i + 1];

    if (!strcmp(key, "-m")) {
      if (strcmp(value, "DDR") && strcmp(value, "XYRAM")) {
        printf("Unrecognized program option: %s\n", value);
        print_usage();
        return 1;
      }
      is_data_in_xyram = strcmp(value, "XYRAM") ? 0 : 1;
      continue;
    }

    int32_t *dim;
    if (!strcmp(key, "-r0")) {
      dim = &rows0;
    } else if (!strcmp(key, "-c0")) {
      dim = &cols0;
    } else if (!strcmp(key, "-r1")) {
      dim = &rows1;
    } else if (!strcmp(key, "-c1")) {
      dim = &cols1;
    } else {
      printf("Unrecognized program option: %s\n", key);
      print_usage();
      return 1;
    }
    if (parse_dim(value, dim)) {
      printf("Error: invalid value of option %s: %s\n", key, value);
      print_usage();
      return 1;
    }
  }

  if (is_data_in_xyram < 0 || rows0 == 0 || cols0 == 0 || rows1 == 0 || cols1 == 0) {
    printf("Error: missing program option!\n");
    print_usage();
    return 1;
  }

  /* Both bounds are checked so that next_pow2() and the buffer sizes below
   * always see a value in [CONV_DIM_MIN, CONV_DIM_MAX]. */
  const int32_t conv_rows = rows0 + rows1 - 1;
  const int32_t conv_cols = cols0 + cols1 - 1;
  if (conv_rows < CONV_DIM_MIN || conv_rows > CONV_DIM_MAX) {
    printf("Error: wrong total rows number of convolution!\n");
    print_usage();
    return 1;
  }
  if (conv_cols < CONV_DIM_MIN || conv_cols > CONV_DIM_MAX) {
    printf("Error: wrong total cols number of convolution!\n");
    print_usage();
    return 1;
  }

  if (is_data_in_xyram) {
    printf("start FFT-convolution test 2d XYRAM\n");
  } else {
    printf("start FFT-convolution test 2d DDR\n");
  }

  printf("   rows0 |   cols0 |   rows1 |   cols1 | conv_rows | conv_cols |     ticks | status\n");
  printf("-----------------------------------------------------------------------------------\n");

  const int32_t conv_size = conv_rows * conv_cols;
  /* Transform size passed to fft_conv_2d(): each dimension rounded up to a
   * power of two. */
  const int32_t conv_rows2 = next_pow2(conv_rows);
  const int32_t conv_cols2 = next_pow2(conv_cols);
  const int32_t conv_size2 = conv_rows2 * conv_cols2;
  /* Two floats per element (presumably interleaved re/im -- TODO confirm
   * against the fft_conv_2d documentation). */
  const int32_t buf_len = conv_size2 * 2;
  const size_t buf_bytes = (size_t)buf_len * sizeof(float);

  /* All owned buffers start NULL so the single cleanup path can free them
   * unconditionally (free(NULL) is a no-op). */
  int exit_code = 1;
  float *src0 = NULL;
  float *src1 = NULL;
  float *dst = NULL;
  float *ethalon = NULL;

  src0 = load_aligned(buf_len, infiles[0]);
  if (src0 == NULL) {
    goto cleanup;
  }
  src1 = load_aligned(buf_len, infiles[1]);
  if (src1 == NULL) {
    goto cleanup;
  }
  dst = memalign(buf_bytes, buf_bytes);
  if (dst == NULL) {
    printf("Error: memory allocation failed!\n");
    goto cleanup;
  }

  {
    int32_t ticks_start, ticks_end;
    int32_t instrs_start, instrs_end;
    ticks_counter(&ticks_start, &instrs_start);

    // FFT-convolution
    const int retval = fft_conv_2d(src0, src1, dst, conv_rows2, conv_cols2, is_data_in_xyram);

    ticks_counter(&ticks_end, &instrs_end);
    const int32_t ticks = ticks_end - ticks_start;

    /* Inputs are no longer needed; release them before the reference buffer
     * is allocated to keep peak memory use down. */
    free(src0);
    src0 = NULL;
    free(src1);
    src1 = NULL;

    int errors_count = 0;
    if (retval) {
      print_error_message(retval);
      errors_count = 1;
    } else {
      // Compare results
      ethalon = malloc((size_t)conv_size * 2 * sizeof(float));
      if (ethalon == NULL) {
        printf("Error: memory allocation failed!\n");
        goto cleanup;
      }
      if (read_data_from_file(ethalon, conv_size * 2, outfiles[0])) {
        goto cleanup;
      }
      const float norm = maximum_norm_conv2d(dst, ethalon, conv_rows, conv_cols);
      /* NOTE(review): the conv_rows/conv_cols columns print the padded
       * transform sizes (conv_rows2/conv_cols2) -- confirm this is intended. */
      printf("%8d |%8d |%8d |%8d |%10d |%10d |%10d | ", rows0, cols0, rows1, cols1, conv_rows2,
             conv_cols2, ticks);
      if (norm < 1e-1) {
        printf("passed\n");
      } else {
        errors_count += 1;
        printf("failed\n");
      }
    }

    printf("end FFT-convolution 2d test\n");
    exit_code = errors_count;
  }

cleanup:
  free(src0);
  free(src1);
  free(dst);
  free(ethalon);
  return exit_code;
}
