// Copyright 2025 RnD Center "ELVEES", JSC

#ifndef MAT_MUL_WITH_DMA_FL32_H
#define MAT_MUL_WITH_DMA_FL32_H

#include <stdio.h>
#include <string.h>

#include <cassert>

#include "elcore50-matrix-lib/common.h"
#include "elcore50-matrix-lib/elcore50.h"

/// Структура для запуска потайлового алгоритма матричного умножения формата
/// float32
typedef struct MatMulFl32Config {
  float* buf_A[2];   ///< указатель на буферы тайловой обработки первой матрицы
  float* buf_B[2];   ///< указатель на буферы тайловой обработки второй матрицы
  float* buf_C[2];   ///< указатель на буферы тайловой обработки результирующей
                     ///< матрицы
  int buf_row0;      ///< к-во строк буфера первой матрицы
  int buf_row1col0;  ///< к-во строк буфера второй матрицы и столбцов первой
  int buf_col1;      ///< к-во столбцов буфера второй матрицы

  VDMAChain chain_A;     ///< dma цепочки для тайлов первой матрицы
  VDMAChain chain_B;     ///< dma цепочки для тайлов второй матрицы
  VDMAChain chain_ld_C;  ///< dma цепочки для тайлов загрузки результата
  VDMAChain chain_st_C;  ///< dma цепочки для тайлов выгрузки результата

  int len_chain_A;     ///< к-во dma цепочек для тайлов первой матрицы
  int len_chain_B;     ///< к-во dma цепочек для тайлов второй матрицы
  int len_chain_ld_C;  ///< к-во dma цепочек для загрузки тайлов результирующей
                       ///< матрицы
  int len_chain_st_C;  ///< к-во dma цепочек для выгрузки тайлов результирующей
                       ///< матрицы

} MatMulFl32Config;

/// Освобождение данных структуры запуска
void destroy_dma_chain_mat_mul_fl32(MatMulFl32Config* config  ///< [in] структура для запуска потайловой обработки
);

/// Подбор оптимальных размеров тайлов для умножения матриц формата float32
void size_selector_mat_mul_fl32(int M,            ///< [in]  к-во строк первой матрицы
                                int K,            ///< [in]  к-во столбцов первой матрицы
                                int N,            ///< [in]  к-во столбцов второй матрицы
                                int& buf_M,       ///< [out] к-во строк тайла первой матрицы
                                int& buf_K,       ///< [in]  к-во столбцов тайла первой матрицы
                                int& buf_N,       ///< [in]  к-во столбцов тайла второй матрицы
                                int count_buf_A,  ///< [in]  к-во используемых буферов (двубуферная
                                                  ///< или однобуферная схема)
                                int count_buf_B,  ///< [in]  к-во используемых буферов
                                int count_buf_C,  ///< [in]  к-во используемых буферов
                                int& offsetA,     ///< [out] смещение тайла первой матрицы
                                int& offsetB      ///< [out] смещение тайла второй матрицы
);

/// Инициализация и заполнение конфигурационной структуры для алгоритма
/// матричного умножения формата float32
void init_dma_chain_matmul_fl32(float* src0,              ///< [in]  указатель на первую матрицу
                                int row0,                 ///< [in]  к-во строк первой матрицы
                                int row1col0,             ///< [in]  к-во столбцов первой матрицы
                                float* src1,              ///< [in]  указатель на вторую матрицу
                                int col1,                 ///< [in]  к-во столбцов второй матрицы
                                float* dst,               ///< [in]  указатель на результирующую матрицу
                                int& offset_A,            ///< [out] смещение тайла первой матрицы
                                int& offset_B,            ///< [out] смещение тайла второй матрицы
                                MatMulFl32Config* config  ///< [out] структура для запуска потайловой обработки
);

/// Запуск потайловой обработки алгоритма матричного умножения формата float32
void run_matmul_fl32(float* src0,              ///< [in]  указатель на первую матрицу
                     int row0,                 ///< [in]  к-во строк первой матрицы
                     int row1col0,             ///< [in]  к-во столбцов первой матрицы
                     float* src1,              ///< [in]  указатель на вторую матрицу
                     int col1,                 ///< [in]  к-во столбцов второй матрицы
                     float* dst,               ///< [out]  указатель на результирующую матрицу
                     int offset_A,             ///< [in] смещение тайла первой матрицы
                     int offset_B,             ///< [in] смещение тайла второй матрицы
                     MatMulFl32Config* config  ///< [in] структура для запуска потайловой обработки
);

extern "C" void mm_v0_vliw_1_sub_matrix_pre_load_real_out_offset(float* in1, int row, int col0row1, float* in2,
                                                                 int col1, float* out, int x_offset, int y_offset,
                                                                 int real_row, int real_col, int* tics, int* instr,
                                                                 int offsetA, int offsetB);

extern "C" void mm_v0_vliw_1_sub_matrix_pre_load_real_out_offset_local_mem(float* in1, int row, int col0row1,
                                                                           float* in2, int col1, float* out,
                                                                           int x_offset, int y_offset, int real_row,
                                                                           int real_col, int* tics, int* instr,
                                                                           int offsetA, int offsetB);

extern "C" void mm_16x64_local_mem(float* in1, int row, int col0row1, float* in2, int col1, float* out, int flag,
                                   int real_col, int* tics, int* instr, int offsetA, int offsetB);

extern "C" void mm_32x32_local_mem(float* in1, int row, int col0row1, float* in2, int col1, float* out, int flag,
                                   int real_col, int* tics, int* instr, int offsetA, int offsetB);

#endif
