// Copyright 2025 RnD Center "ELVEES", JSC

#ifndef CONVOLUTION_SMALL_ASM_KER_H
#define CONVOLUTION_SMALL_ASM_KER_H

#include "elcore50-matrix-lib/common.h"
#include "elcore50-matrix-lib/elcore50.h"

// Версии ядер свертки
enum Conv_version { VER_1x16x128, VER_4x4x128, VER_4x8x64, VER_8x8x32 };

/// Структура для запуска потайлового алгоритма свертки формата float16
typedef struct ConvFl16Config {
  uint16_t* buf_A[2];   ///< указатель на буферы тайловой обработки первой матрицы
  uint16_t* buf_B[2];   ///< указатель на буферы тайловой обработки второй матрицы
  uint16_t* buf_C[2];   ///< указатель на буферы тайловой обработки
                        ///< результирующей матрицы
  float* buf_init_vec;  ///< указатель на буфер с инициализирующим вектором
  int buf_srcH;         ///< высота буфера для входного тензора
  int buf_srcW;         ///< ширина буфера для входного тензора
  int buf_srcC;         ///< к-во каналов буфера для входного тензора
  int buf_dstH;         ///< высота буфера для выходного тензора
  int buf_dstW;         ///< ширина буфера для выходного тензора
  int buf_dstC;         ///< к-во каналов буфера для выходного тензора
  int offset_A;         ///< смещение для входного тензора
  int tile_padX;        ///< паддинг тайла
  int tile_padY;        ///< паддинг тайла
  int tile_padW;        ///< паддинг тайла
  int tile_padH;        ///< паддинг тайла

  VDMAChain chain_A;         ///< dma цепочки для тайлов входного тензора
  VDMAChain chain_B;         ///< dma цепочки для тайлов весов
  VDMAChain chain_ld_C;      ///< dma цепочки для тайлов загрузки результата
  VDMAChain chain_st_C;      ///< dma цепочки для тайлов выгрузки результата
  VDMAChain chain_init_vec;  ///< dma цепочки для тайлов инициализирующего вектора
  int len_chain_A;           ///< к-во dma цепочек для тайлов входного тензора
  int len_chain_B;           ///< к-во dma цепочек для тайлов весов
  int len_chain_ld_C;        ///< к-во dma цепочек для загрузки тайлов результата
  int len_chain_st_C;        ///< к-во dma цепочек для выгрузки тайлов результата

  uint16_t* ptr_zeros;      ///< указатель на нулевые данные для быстрого заполнения
                            ///< данных
  float* ptr_inter_result;  ///< указатель память для хранения промежуточного
                            ///< результата
} ConvFl16Config;

/// Инициализация и заполнение конфигурационной структуры для алгоритма свертки
/// формата float16
void init_dma_chain_conv_fl16(Tensor_fl16* input_tensor,   ///< [in]  структура для входного тензора
                              int group,                   ///< [in]  группа
                              int padX,                    ///< [in]  паддинг слева
                              int padW,                    ///< [in]  паддинг справа
                              int padY,                    ///< [in]  паддинг сверху
                              int padH,                    ///< [in]  паддинг снизу
                              Weight_fl16* input_weight,   ///< [in]  структура весов
                              float* bias,                 ///< [in]  указатель на bias
                              Tensor_fl16* output_tensor,  ///< [in]  структура для выходного тензора
                              ConvFl16Config* config,      ///< [out] структура для запуска потайлового алгоритма
                              bool add_bias_flag,          ///< [in]  флаг наличия bias
                              uint16_t* start_adr,         ///< [in]  адрес начала локальной памяти
                              int local_mem_size           ///< [in]  размер локальной памяти
);

/// Запуск потайловой обработки алгоритма свертки формата float16
void run_conv_fl16(Tensor_fl16* input_tensor,               ///< [in]  структура для входного тензора
                   int group,                               ///< [in]  группа
                   Weight_fl16* input_weight,               ///< [in]  структура весов
                   float* bias,                             ///< [in]  указатель на bias
                   Tensor_fl16* output_tensor,              ///< [out] структура для выходного тензора
                   ConvFl16Config* config,                  ///< [in]  структура для запуска потайлового алгоритма
                   Store_version st_ver,                    ///< [in]  версия постобработки
                   bool add_bias_flag,                      ///< [in]  флаг наличия bias
                   Conv_version conv_version = VER_4x4x128  ///< [in] версия ядра свертки
);

/// Освобождение данных структуры запуска
void destroy_dma_chain_conv_fl16(ConvFl16Config* config  ///< [in] структура для запуска потайловой обработки
);

/// Подбор оптимальных размеров тайлов для алгоритма свертки формата float16
void conv_size_selector(int batch,               ///< [in]  к-во батчей входного тензора
                        int srcH,                ///< [in]  высота входного тензора
                        int srcW,                ///< [in]  ширина входного тензора
                        int srcC,                ///< [in]  к-во каналов входного тензора
                        int kernelY,             ///< [in]  высота ядра свертки
                        int kernelX,             ///< [in]  ширина ядра свертки
                        int dstH,                ///< [in]  высота выходного тензора
                        int dstW,                ///< [in]  ширина выходного тензора
                        int dstC,                ///< [in]  к-во каналов выходного тензора
                        ConvFl16Config* config,  ///< [out] структура для запуска потайлового алгоритма
                        int strideY,             ///< [in]  страйд по высоте
                        int strideX,             ///< [in]  страйд по ширине
                        int count_buf_A,         ///< [in]  к-во буферов для входного тензора
                        int count_buf_B,         ///< [in]  к-во буферов для весов
                        int count_buf_C,         ///< [in]  к-во буферов для результата
                        bool bias_flag,          ///< [in]  флаг наличия bias
                        int local_mem_size,      ///< [in]  указатель на начало локальной памяти
                        int* offset_A            ///< [out] оффсет для входного тензора
);

/// Инициализация и заполнение конфигурационной структуры для алгоритма свертки
/// формата float16
void init_chain_dma_convolution_small_ker_2A2B2C(
    Tensor_fl16* input_tensor,   ///< [in]  структура для входного тензора
    int group,                   ///< [in]  группа
    int padX,                    ///< [in]  паддинг слева
    int padW,                    ///< [in]  паддинг справа
    int padY,                    ///< [in]  паддинг сверху
    int padH,                    ///< [in]  паддинг снизу
    Weight_fl16* input_weight,   ///< [in]  структура весов
    float* bias,                 ///< [in]  указатель на bias
    Tensor_fl16* output_tensor,  ///< [in]  структура для выходного тензора
    ConvFl16Config* config,      ///< [out] структура для запуска потайлового алгоритма
    bool add_bias_flag,          ///< [in]  флаг наличия bias
    uint16_t* start_adr,         ///< [in]  адрес начала локальной памяти
    int local_mem_size           ///< [in]  размер локальной памяти
);

/// Запуск потайловой обработки алгоритма свертки формата float16
void convolution_small_ker_2A2B2C(Tensor_fl16* input_tensor,   ///< [in]  структура для входного тензора
                                  int group,                   ///< [in]  группа
                                  Weight_fl16* input_weight,   ///< [in]  структура весов
                                  float* bias,                 ///< [in]  указатель на bias
                                  Tensor_fl16* output_tensor,  ///< [out] структура для выходного тензора
                                  ConvFl16Config* config,      ///< [in]  структура для запуска потайлового алгоритма
                                  Store_version st_ver,        ///< [in]  версия постобработки
                                  bool add_bias_flag,          ///< [in]  флаг наличия bias
                                  Conv_version conv_version = VER_4x4x128  ///< [in] версия ядра свертки
);

/// Инициализация и заполнение конфигурационной структуры для алгоритма свертки
/// формата float16 Версия с шагом по каналам входа.
void init_chain_dma_convolution_small_ker_2A2B2C_srcC_ver(
    Tensor_fl16* input_tensor,   ///< [in]  структура для входного тензора
    int group,                   ///< [in]  группа
    int padX,                    ///< [in]  паддинг слева
    int padW,                    ///< [in]  паддинг справа
    int padY,                    ///< [in]  паддинг сверху
    int padH,                    ///< [in]  паддинг снизу
    Weight_fl16* input_weight,   ///< [in]  структура весов
    float* bias,                 ///< [in]  указатель на bias
    Tensor_fl16* output_tensor,  ///< [in]  структура для выходного тензора
    ConvFl16Config* config,      ///< [out] структура для запуска потайлового алгоритма
    bool add_bias_flag,          ///< [in]  флаг наличия bias
    uint16_t* start_adr,         ///< [in]  адрес начала локальной памяти
    int local_mem_size           ///< [in]  размер локальной памяти
);

/// Запуск потайловой обработки алгоритма свертки формата float16.
/// Версия с шагом по каналам входа.
void convolution_small_ker_2A2B2C_srcC_ver(
    Tensor_fl16* input_tensor,               ///< [in]  структура для входного тензора
    int group,                               ///< [in]  группа
    Weight_fl16* input_weight,               ///< [in]  структура весов
    float* bias,                             ///< [in]  указатель на bias
    Tensor_fl16* output_tensor,              ///< [out] структура для выходного тензора
    ConvFl16Config* config,                  ///< [in]  структура для запуска потайлового алгоритма
    Store_version st_ver,                    ///< [in]  версия постобработки
    bool add_bias_flag,                      ///< [in]  флаг наличия bias
    Conv_version conv_version = VER_4x4x128  ///< [in] версия ядра свертки
);

extern "C" void convolution_16x128_fl16_ver1(uint16_t* src, uint16_t* weight, int src_width_stride,
                                             int src_height_stride, int src_batch_stride, int weight_width_stride,
                                             int weight_height_stride, int kernelX, int kernelY, int stride_X_in_bytes,
                                             int stride_Y_in_bytes, int loop_srcC_kerX_iter_count,
                                             int32_t* tic_counter,
                                             int32_t* instr_counter);  // srcC / 8 * kerX

extern "C" void convolution_16x128_fl16_ver1_with_offset(uint16_t* src, uint16_t* weight, int src_width_stride,
                                                         int src_height_stride, int src_batch_stride,
                                                         int weight_width_stride, int weight_height_stride,
                                                         int kernelX, int kernelY, int stride_X_in_bytes,
                                                         int stride_Y_in_bytes,
                                                         int loop_srcC_kerX_iter_count);  // srcC / 8 * kerX

extern "C" void convolution_32x64_fl16_ver1(uint16_t* src, uint16_t* weight, int src_width_stride,
                                            int src_height_stride, int src_batch_stride, int weight_width_stride,
                                            int weight_height_stride, int kernelX, int kernelY, int stride_X_in_bytes,
                                            int stride_Y_in_bytes,
                                            int loop_srcC_kerX_iter_count);  // srcC / 4 * kerX

extern "C" void convolution_8x8x32_fl16_ver1(uint16_t* src, uint16_t* weight, int src_width_stride,
                                             int src_height_stride, int src_batch_stride, int weight_width_stride,
                                             int weight_height_stride, int kernelX, int kernelY, int stride_X_in_bytes,
                                             int stride_Y_in_bytes,
                                             int loop_srcC_kerX_iter_count);  // srcC / 4 * kerX

extern "C" void convolution_8x8x32_fl16_ver1_with_offset(uint16_t* src, uint16_t* weight, int src_width_stride,
                                                         int src_height_stride, int src_batch_stride,
                                                         int weight_width_stride, int weight_height_stride,
                                                         int kernelX, int kernelY, int stride_X_in_bytes,
                                                         int stride_Y_in_bytes,
                                                         int loop_srcC_kerX_iter_count);  // srcC / 4 * kerX

extern "C" void convolution_4x4x128_fl16_ver1(uint16_t* src, uint16_t* weight, int src_width_stride,
                                              int src_height_stride, int src_batch_stride, int weight_width_stride,
                                              int weight_height_stride, int kernelX, int kernelY,
                                              int stride_X_in_bytes, int stride_Y_in_bytes,
                                              int loop_srcC_kerX_iter_count);  // srcC / 8 * kerX

extern "C" void convolution_4x4x128_fl16_ver1_with_offset(uint16_t* src, uint16_t* weight, int src_width_stride,
                                                          int src_height_stride, int src_batch_stride,
                                                          int weight_width_stride, int weight_height_stride,
                                                          int loop_srcC,  // srcC / 8
                                                          int kernelY, int stride_X_in_bytes, int stride_Y_in_bytes,
                                                          int loop_srcC_kerX_iter_count);  // srcC / 8 * kerX

extern "C" void convolution_4x8x64_fl16_ver1(uint16_t* src, uint16_t* weight, int src_width_stride,
                                             int src_height_stride, int src_batch_stride, int weight_width_stride,
                                             int weight_height_stride, int kernelX, int kernelY, int stride_X_in_bytes,
                                             int stride_Y_in_bytes,
                                             int loop_srcC_kerX_iter_count);  // srcC / 4 * kerX

extern "C" void convolution_4x8x64_fl16_ver1_with_offset(uint16_t* src, uint16_t* weight, int src_width_stride,
                                                         int src_height_stride, int src_batch_stride,
                                                         int weight_width_stride, int weight_height_stride,
                                                         int loop_srcC,  // srcC / 8
                                                         int kernelY, int stride_X_in_bytes, int stride_Y_in_bytes,
                                                         int loop_srcC_kerX_iter_count);  // srcC / 4 * kerX

// страйды указаны в байтах
extern "C" void store_accums_16xX(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                  int32_t* tic_counter, int32_t* instr_counter);
extern "C" void store_accums_16xX_relu(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                       int32_t* tic_counter, int32_t* instr_counter);
extern "C" void store_accums_16xX_relu6(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                        int32_t* tic_counter, int32_t* instr_counter);

// even_flag = 0 нечетное, иначе четное
extern "C" void store_accums_Xx128(uint16_t* last_line_dst, int jump_by_height, int even_h_flag, int width,
                                   int src_width_stride, int src_height_stride);

extern "C" void store_accums_Xx128_relu(uint16_t* last_line_dst, int jump_by_height, int even_h_flag, int width,
                                        int src_width_stride, int src_height_stride);

extern "C" void store_accums_Xx128_relu6(uint16_t* last_line_dst, int jump_by_height, int even_h_flag, int width,
                                         int src_width_stride, int src_height_stride);

// страйды указаны в байтах
extern "C" void store_accums_32xX(uint16_t* dst, int width, int src_width_stride, int src_height_stride);

// страйды указаны в байтах
extern "C" void store_accums_4x8x64(uint16_t* dst, int width, int src_width_stride, int src_height_stride);
extern "C" void store_accums_4x8x64_relu(uint16_t* dst, int width, int src_width_stride, int src_height_stride);
extern "C" void store_accums_4x8x64_relu6(uint16_t* dst, int width, int src_width_stride, int src_height_stride);

extern "C" void store_accums_8x8x32(uint16_t* dst, int width, int src_width_stride, int src_height_stride);
extern "C" void store_accums_8x8x32_relu(uint16_t* dst, int width, int src_width_stride, int src_height_stride);
extern "C" void store_accums_8x8x32_relu6(uint16_t* dst, int width, int src_width_stride, int src_height_stride);

// страйды указаны в байтах
extern "C" void store_accums_4Xx8Xx64(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                      int block_w, int block_h, int jump_by_width, int jump_by_height);
extern "C" void store_accums_4Xx8Xx64_relu(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                           int block_w, int block_h, int jump_by_width, int jump_by_height);
extern "C" void store_accums_4Xx8Xx64_relu6(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                            int block_w, int block_h, int jump_by_width, int jump_by_height);

// страйды указаны в байтах
extern "C" void store_accums_8Xx8Xx32(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                      int block_w, int block_h, int jump_by_width, int jump_by_height);
extern "C" void store_accums_8Xx8Xx32_relu(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                           int block_w, int block_h, int jump_by_width, int jump_by_height);
extern "C" void store_accums_8Xx8Xx32_relu6(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                            int block_w, int block_h, int jump_by_width, int jump_by_height);
// страйды указаны в байтах
extern "C" void store_accums_4x4x128(uint16_t* dst, int width, int src_width_stride, int src_height_stride);

extern "C" void store_accums_4x4x128_relu(uint16_t* dst, int width, int src_width_stride, int src_height_stride);
extern "C" void store_accums_4x4x128_relu6(uint16_t* dst, int width, int src_width_stride, int src_height_stride);

// страйды указаны в байтах
extern "C" void store_accums_4Xx4Xx128(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                       int block_w, int block_h, int jump_by_height);
extern "C" void store_accums_4Xx4Xx128_relu(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                            int block_w, int block_h, int jump_by_height);
extern "C" void store_accums_4Xx4Xx128_relu6(uint16_t* dst, int width, int src_width_stride, int src_height_stride,
                                             int block_w, int block_h, int jump_by_height);

// even_flag = 0 нечетное, иначе четное
extern "C" void store_accums_Xx64(uint16_t* last_line_dst, int jump_by_height_outside, int even_h_flag, int width,
                                  int src_width_stride, int src_height_stride, int jump_by_height_inside);

extern "C" void load_vec_regs(void* dst);
extern "C" void store_vec_regs(void* dst);
extern "C" void clear_accums();
extern "C" void add_bias_16x128(float* src);
extern "C" void add_bias_32x64(float* src);
extern "C" void add_bias_32(float* src);

void convolution_small_ker_4x4x128(uint16_t* src, int batch, int src_h, int src_w, int src_c, int kernel_y,
                                   int kernel_x, int stride_y, int stride_x, uint16_t* weight, float* bias,
                                   uint16_t* dst, int dst_h, int dst_w, int dst_c, bool bias_flag, int offset_A,
                                   Store_version st_ver, char** ptr_buf_tics);

void convolution_small_ker_4x8x64(uint16_t* src, int batch, int src_h, int src_w, int src_c, int kernel_y,
                                  int kernel_x, int stride_y, int stride_x, uint16_t* weight, float* bias,
                                  uint16_t* dst, int dst_h, int dst_w, int dst_c, bool bias_flag, int offset_A,
                                  Store_version st_ver);

void convolution_small_ker_8x8x32(uint16_t* src, int batch, int src_h, int src_w, int src_c, int kernel_y,
                                  int kernel_x, int stride_y, int stride_x, uint16_t* weight, float* bias,
                                  uint16_t* dst, int dst_h, int dst_w, int dst_c, bool bias_flag, int offset_A,
                                  Store_version st_ver);

void convolution_small_ker_16x128(uint16_t* src, int batch, int src_h, int src_w, int src_c, int kernel_y,
                                  int kernel_x, int stride_y, int stride_x, uint16_t* weight, float* bias,
                                  uint16_t* dst, int dst_h, int dst_w, int dst_c, bool bias_flag, int offset_A,
                                  Store_version st_ver, char** ptr_buf_tics);

void convolution_small_ker_32x64(uint16_t* src, int batch, int src_h, int src_w, int src_c, int kernel_y, int kernel_x,
                                 int stride_y, int stride_x, uint16_t* weight, float* bias, uint16_t* dst, int dst_h,
                                 int dst_w, int dst_c, bool bias_flag);

extern "C" void store_float32_block_4x4x128(float* data, int ch, int width_stride, int height_stride,
                                            int jump_by_width_0, int jump_by_width_1);

extern "C" void load_float32_block_4x4x128(float* data, int ch, int width_stride, int height_stride,
                                           int jump_by_width_0, int jump_by_width_1);

extern "C" void store_float32_block_4x8x64(float* data, int ch, int width_stride, int height_stride,
                                           int jump_by_width_0, int jump_by_width_1, int jump_by_height);

extern "C" void load_float32_block_4x8x64(float* data, int ch, int width_stride, int height_stride,
                                          int jump_by_width_0, int jump_by_width_1, int jump_by_height);

extern "C" void store_float32_block_8x8x32(float* data, int ch, int width_stride, int height_stride,
                                           int jump_by_width_0, int jump_by_width_1, int jump_by_height);
extern "C" void load_float32_block_8x8x32(float* data, int ch, int width_stride, int height_stride,
                                          int jump_by_width_0, int jump_by_width_1, int jump_by_height);

void convolution_small_ker_4x4x128_srcC_ver(uint16_t* src, int batch, int src_h, int src_w, int src_c, int kernel_y,
                                            int kernel_x, int stride_y, int stride_x, uint16_t* weight, float* bias,
                                            uint16_t* dst, int dst_h, int dst_w, int dst_c, bool bias_flag,
                                            int offset_A, Store_version st_ver, int load_format, int store_format,
                                            bool start_clear_acc, int tile_size_dstC);

void convolution_small_ker_4x8x64_srcC_ver(uint16_t* src, int batch, int src_h, int src_w, int src_c, int kernel_y,
                                           int kernel_x, int stride_y, int stride_x, uint16_t* weight, float* bias,
                                           uint16_t* dst, int dst_h, int dst_w, int dst_c, bool bias_flag,
                                           int offset_A, Store_version st_ver,
                                           int load_format,   // 0 - без загузки, 1 - загрузка float32
                                           int store_format,  // 2 - store float32, 1 - store float16, 0 - без выгрузки
                                           bool start_clear_acc, int tile_size_dstC);

void convolution_small_ker_8x8x32_srcC_ver(uint16_t* src, int batch, int src_h, int src_w, int src_c, int kernel_y,
                                           int kernel_x, int stride_y, int stride_x, uint16_t* weight, float* bias,
                                           uint16_t* dst, int dst_h, int dst_w, int dst_c, bool bias_flag,
                                           int offset_A, Store_version st_ver,
                                           int load_format,   // 0 - без загузки, 1 - загрузка float32
                                           int store_format,  // 2 - store float32, 1 - store float16, 0 - без выгрузки
                                           bool start_clear_acc, int tile_size_dstC);

void ref_convolution(const float* src, int batch, int srcH, int srcW, int srcC, int kernelY, int kernelX,
                     int dilationY, int dilationX, int strideY, int strideX, int padY, int padX, int padH, int padW,
                     int group, const float* weight, const float* bias, float* dst, int dstC);

#endif
