# Copyright 2025 RnD Center "ELVEES", JSC

import argparse
from array import array
import cmath
import math
import numpy as np
import scipy.signal as sc

def _positive_int(text):
    """Parse ``text`` as an integer and reject values below 1.

    Used as an argparse ``type`` so that bad matrix dimensions are reported as a
    normal usage error instead of failing later inside the numeric code.
    """
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
    return value


# Command-line interface: the dimensions of both input matrices and the output directory.
# Every option is mandatory because the script has no defaults to fall back on; an omitted
# option would otherwise reach the arithmetic below as None.
parser = argparse.ArgumentParser(
    description="Makes files with data for DSP tests", formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
    "-r0", "--rows0", type=_positive_int, required=True, help="Number of rows of the first array"
)
parser.add_argument(
    "-c0", "--cols0", type=_positive_int, required=True, help="Number of cols of the first array"
)
parser.add_argument(
    "-r1", "--rows1", type=_positive_int, required=True, help="Number of rows of the second array"
)
parser.add_argument(
    "-c1", "--cols1", type=_positive_int, required=True, help="Number of cols of the second array"
)
parser.add_argument(
    "-l", "--location", type=str, required=True, help="Location of the test_data files"
)
args = parser.parse_args()

# Reference result: full 2-D convolution of two synthetic complex matrices, computed by SciPy
def _unit_phasors(count, step):
    """Return the list ``exp(2*pi*k*step/count)`` for ``k = 0 .. count-1``.

    ``step`` is passed as an imaginary literal (e.g. ``3j``), which makes each
    sample a complex exponential of a purely imaginary argument.
    """
    return [cmath.exp(2 * cmath.pi * k * step / count) for k in range(count)]


rows0, cols0 = args.rows0, args.cols0
rows1, cols1 = args.rows1, args.cols1

# Dimensions of the full (untruncated) convolution output
conv_rows = rows0 + rows1 - 1
conv_cols = cols0 + cols1 - 1

# First operand: a column of phasors broadcast against a row of phasors (rows0 x cols0)
column0 = [[sample] for sample in _unit_phasors(rows0, 3j)]
line0 = _unit_phasors(cols0, 1j)
src0 = np.multiply(column0, line0).tolist()

# Second operand: built the same way with different steps (rows1 x cols1)
column1 = [[sample] for sample in _unit_phasors(rows1, 2j)]
line1 = _unit_phasors(cols1, 4j)
src1 = np.multiply(column1, line1).tolist()

dst = sc.convolve2d(src0, src1)

# Zero-extend both operands to a common power-of-two shape for the DSP tests
def _zero_extend(matrix, rows, cols):
    """Return a ``rows`` x ``cols`` complex array with ``matrix`` in its top-left corner.

    All remaining elements are zero.
    """
    block = np.asarray(matrix)
    extended = np.zeros((rows, cols), dtype=complex)
    extended[: block.shape[0], : block.shape[1]] = block
    return extended


# Smallest powers of two that are not less than the convolution dimensions
conv_rows2 = 1 << math.ceil(math.log2(conv_rows))
conv_cols2 = 1 << math.ceil(math.log2(conv_cols))

src0 = _zero_extend(src0, conv_rows2, conv_cols2)
src1 = _zero_extend(src1, conv_rows2, conv_cols2)

# Make binary files with data for DSP tests
def _write_complex_matrix(path, matrix):
    """Write ``matrix`` to ``path`` as raw single-precision floats.

    Elements are stored row by row. Each complex element becomes two consecutive
    32-bit floats in native byte order: the imaginary part first, then the real part.
    NOTE(review): the imaginary-first order is kept exactly as the script has always
    written it; presumably the DSP side expects this layout - confirm before changing.
    """
    values = np.asarray(matrix)
    # A trailing axis of length 2 holds the (imag, real) pair of every element, so the
    # C-order byte dump below is already interleaved the way the file format needs.
    interleaved = np.empty(values.shape + (2,), dtype=np.float32)
    interleaved[..., 0] = values.imag
    interleaved[..., 1] = values.real
    with open(path, "wb") as write_data:
        write_data.write(interleaved.tobytes())


_write_complex_matrix(f"{args.location}/input_data0.bin", src0)
_write_complex_matrix(f"{args.location}/input_data1.bin", src1)
_write_complex_matrix(f"{args.location}/output_data0.bin", dst)
