// Copyright 2025 RnD Center "ELVEES", JSC

#include "convolve_tests_helper.h"

// Fills every element of `tensor` with a pseudo-random value from {-1, 0, 1}.
//
// Elements are visited in batch -> height -> width -> channel order with the channel
// index varying fastest. That is the same order in which the flat data buffer is
// indexed, so the offset simply advances by one per element.
// Values are drawn from rand(), one call per element.
void set_data_tensor(Tensor_fl32* tensor) {
  int offset = 0;
  for (int n = 0; n < tensor->batch; ++n) {
    for (int y = 0; y < tensor->height; ++y) {
      for (int x = 0; x < tensor->width; ++x) {
        for (int c = 0; c < tensor->channel; ++c, ++offset) {
          // rand() % 3 is in [0, 2]; shifting by one gives -1, 0 or 1.
          tensor->data[offset] = rand() % 3 - 1;
        }
      }
    }
  }
}

// Fills a weight buffer stored as [dstC][kernelY][kernelX][srcC] with pseudo-random
// values from {-1, 0, 1}.
//
// The loop nest walks the buffer in exactly its storage order (srcC varies fastest),
// so a single running offset replaces the per-element index computation.
// Values are drawn from rand(), one call per element.
void set_data_weight_dyxc(Weight_fl32* weight) {
  int offset = 0;
  for (int d = 0; d < weight->dstC; ++d) {
    for (int ky = 0; ky < weight->kernelY; ++ky) {
      for (int kx = 0; kx < weight->kernelX; ++kx) {
        for (int c = 0; c < weight->srcC; ++c, ++offset) {
          // rand() % 3 is in [0, 2]; shifting by one gives -1, 0 or 1.
          weight->data[offset] = rand() % 3 - 1;
        }
      }
    }
  }
}

// Reorders weights from [dstC][kernelY][kernelX][srcC] layout (`input`) into
// [kernelY][kernelX][srcC][dstC] layout (`output`).
//
// All extents are read from `output`; `input` is indexed with those same extents,
// so both descriptors are expected to describe the same dstC/kernelY/kernelX/srcC.
// `output->data` is written sequentially, `input->data` is read with a stride.
void weight_dyxc_to_yxcd(const Weight_fl32* input, Weight_fl32* output) {
  int dst_pos = 0;  // walks the output buffer in storage order
  for (int ky = 0; ky < output->kernelY; ++ky) {
    for (int kx = 0; kx < output->kernelX; ++kx) {
      for (int c = 0; c < output->srcC; ++c) {
        for (int d = 0; d < output->dstC; ++d, ++dst_pos) {
          // Flat position of element (d, ky, kx, c) in the source layout.
          const int src_pos = ((d * output->kernelY + ky) * output->kernelX + kx) * output->srcC + c;
          output->data[dst_pos] = input->data[src_pos];
        }
      }
    }
  }
}

// Prints channel 0 of `tensor` to stdout as a grid.
//
// Each pixel is printed as "%5.1f " followed by '|'; every image row ends with a
// newline and every batch entry is followed by an extra blank line.
// Only the first channel of each pixel is printed; the remaining channels are skipped.
void print_tensor(Tensor_fl32* tensor) {
  for (int n = 0; n < tensor->batch; ++n) {
    for (int y = 0; y < tensor->height; ++y) {
      for (int x = 0; x < tensor->width; ++x) {
        // Offset of channel 0 of pixel (n, y, x); channels of a pixel are adjacent.
        const int pixel = ((n * tensor->height + y) * tensor->width + x) * tensor->channel;
        printf("%5.1f ", tensor->data[pixel]);
        printf("|");
      }
      printf("\n");
    }
    printf("\n");
  }
}

// Compares two float arrays element by element and reports mismatches on stdout.
//
// For every index i in [0, size):
//  - if either value is NaN, a "NAN" line is printed and 1 is returned immediately;
//  - if both values are non-zero, the relative error |a - b| / min(|a|, |b|) must not
//    exceed `eps`;
//  - otherwise (at least one value is exactly zero) the larger magnitude must not
//    exceed `eps`.
// Two infinities match only when they have the same sign.
//
// At most 10 mismatches are printed and counted. The scan still continues after that
// so that a later NaN is detected.
//
// Returns 0 when the arrays match, 1 when a NaN was found, otherwise the number of
// mismatches (capped at 10).
int comparator_fl32(float* src0, float* src1, int size, float eps) {
  enum { MAX_REPORTED_ERRORS = 10 };
  int err_counter = 0;

  for (int i = 0; i < size; ++i) {
    // NaN is the only value that compares unequal to itself.
    if (src0[i] != src0[i] || src1[i] != src1[i]) {
      printf("NAN: [%d] %f != %f\n", i, src0[i], src1[i]);
      return 1;
    }

    // Once the cap is reached only the NaN check above is still performed.
    if (err_counter >= MAX_REPORTED_ERRORS) {
      continue;
    }

    const double mag0 = fabs(src0[i]);
    const double mag1 = fabs(src1[i]);

    if (src0[i] != 0 && src1[i] != 0) {
      const double rel_err = fabs(src0[i] - src1[i]) / (mag0 < mag1 ? mag0 : mag1);

      // With two infinite inputs the quotient is NaN and never compares greater than
      // eps; without this check +inf vs -inf would be accepted as equal.
      const int inf_sign_mismatch = (rel_err != rel_err) && (src0[i] != src1[i]);

      if (rel_err > eps || inf_sign_mismatch) {
        printf("Error: [%d] %f != %f err = %f\n", i, src0[i], src1[i], rel_err);
        ++err_counter;
      }
    } else if ((mag0 > mag1 ? mag0 : mag1) > eps) {
      // A relative error is undefined against zero, so fall back to an absolute check.
      printf("Error: [%d] %f != %f\n", i, src0[i], src1[i]);
      ++err_counter;
    }
  }

  return err_counter;
}

// Copies the contents of `src` into the leading corner of `dst`.
//
// `dst` is first zero-filled over batch * height * width * channel float-sized
// elements. Then every element (b, h, w, ch) of `src` is written to the same
// coordinates in `dst`, using `dst`'s own extents to compute the position.
// Any part of `dst` not covered by `src` therefore stays zero.
// No bounds are checked: `dst` must be at least as large as `src` in every dimension.
void tensor_data_copy(Tensor_fl32* src, Tensor_fl32* dst) {
  memset(dst->data, 0, dst->batch * dst->height * dst->width * dst->channel * sizeof(float));

  int src_pos = 0;  // src is read in storage order, so its offset just advances
  for (int n = 0; n < src->batch; ++n) {
    for (int y = 0; y < src->height; ++y) {
      for (int x = 0; x < src->width; ++x) {
        // Offset of channel 0 of pixel (n, y, x) inside dst.
        const int dst_base = ((n * dst->height + y) * dst->width + x) * dst->channel;
        for (int c = 0; c < src->channel; ++c, ++src_pos) {
          dst->data[dst_base + c] = src->data[src_pos];
        }
      }
    }
  }
}

// Copies weights stored as [kernelY][kernelX][srcC][dstC] from `src` into the
// leading corner of `dst`, which uses the same layout with its own extents.
//
// `dst` is first zero-filled over dstC * srcC * kernelX * kernelY float-sized
// elements, so positions not covered by `src` stay zero.
// No bounds are checked: `dst` must be at least as large as `src` in every dimension.
void weight_data_copy(Weight_fl32* src, Weight_fl32* dst) {
  memset(dst->data, 0, dst->dstC * dst->srcC * dst->kernelX * dst->kernelY * sizeof(float));

  int src_pos = 0;  // src is read in storage order, so its offset just advances
  for (int ky = 0; ky < src->kernelY; ++ky) {
    for (int kx = 0; kx < src->kernelX; ++kx) {
      for (int c = 0; c < src->srcC; ++c) {
        // Offset of output channel 0 for tap (ky, kx, c) inside dst.
        const int dst_base = ((ky * dst->kernelX + kx) * dst->srcC + c) * dst->dstC;
        for (int d = 0; d < src->dstC; ++d, ++src_pos) {
          dst->data[dst_base + d] = src->data[src_pos];
        }
      }
    }
  }
}
