# Copyright 2024 RnD Center "ELVEES", JSC

import glob
import os
import shutil
import struct
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import fft, signal


def test_signal_lib_fft(request):
    """Run every FFT/IFFT benchmark binary and plot the measured timings.

    The 24 binaries (3 data types x 4 transforms x 2 memory locations) located in the
    ``elf`` directory next to this file are launched one by one with ``elcorecl-run``.
    Afterwards ``<case name>.csv`` is read from the current working directory for every
    case, its ``TIC_CNTR`` column is divided by the scaled DSP clock rate and plotted on
    a 3x4 grid of log-scaled axes (rows: data type, columns: transform); the DDR and
    XYRAM curves of one case share an axes. The figure and every file matching
    ``test_*FFT*.csv`` end up in the ``--artifacts-path`` directory.

    Args:
        request: pytest ``request`` fixture, used to read the ``--artifacts-path`` option.

    Raises:
        subprocess.CalledProcessError: a benchmark exited with a non-zero status.
        subprocess.TimeoutExpired: a benchmark ran longer than 100 seconds.
    """
    data_types = ("FLOAT", "HFLOAT", "FRACTIONAL")
    transforms = ("FFT 1D", "FFT 2D", "IFFT 1D", "IFFT 2D")
    data_locs = ("DDR", "XYRAM")

    def case_name(data_type, transform, mem):
        # e.g. ("FLOAT", "FFT 1D", "DDR") -> "test_FLOAT_FFT_1D_DDR"
        return f"test_{data_type}_{transform.replace(' ', '_')}_{mem}"

    # Read the option first so a misconfigured run fails before the benchmarks start.
    artifacts_path = request.config.getoption("--artifacts-path")

    # Joining the components separately keeps the path relative when __file__ has no
    # directory part (string concatenation would produce "/elf/...").
    elf_dir = os.path.join(os.path.dirname(__file__), "elf")
    for data_type in data_types:
        for transform in transforms:
            for mem in data_locs:
                binary_path = os.path.join(elf_dir, case_name(data_type, transform, mem) + ".elf")
                subprocess.run(
                    ["elcorecl-run", "-e", binary_path],
                    check=True,
                    timeout=100,
                )

    # Presumably the file holds the clock rate in Hz, so the scaled value is MHz and
    # ticks / MHz gives microseconds, matching the axis label below - TODO confirm.
    with open("/sys/kernel/debug/clk/dsp0_clk/clk_rate", encoding="ascii") as clk_file:
        dsp_freq = int(clk_file.read()) * 1e-6

    fig, axs = plt.subplots(nrows=3, ncols=4, figsize=(42, 13))
    try:
        fig.suptitle("FFT elcore50 performance")

        for row, data_type in enumerate(data_types):
            for col, transform in enumerate(transforms):
                ax = axs[row, col]
                for mem in data_locs:
                    timings = pd.read_csv(
                        f"{case_name(data_type, transform, mem)}.csv", index_col=0
                    )
                    ax.plot(timings["TIC_CNTR"] / dsp_freq, label=mem)
                ax.grid(True)
                ax.set_yscale("log")

        # Row captions on the first column, column captions on the first row.
        for row, data_type in enumerate(data_types):
            axs[row, 0].set_ylabel(data_type)
        for col, transform in enumerate(transforms):
            axs[0, col].set_title(transform)

        # Units and legend are shown once, on the top-right axes.
        axs[0, 3].set_ylabel("time, \u03BCs")
        axs[0, 3].set_xlabel("(I)FFT size, complex elements")
        axs[0, 3].legend()

        fig.savefig(f"{artifacts_path}/elcore50_fft_performance.png")
    finally:
        # Close this figure explicitly, also when plotting or saving fails.
        plt.close(fig)

    # Move .csv files to artifacts
    for csv_file in glob.glob("test_*FFT*.csv"):
        shutil.move(csv_file, f"{artifacts_path}/{csv_file}")


def test_signal_lib_conv_1d(request):
    """Run the 1D convolution benchmark for each memory location and archive its CSVs.

    ``elf/test_conv_1d.elf`` (next to this file) is launched through ``elcorecl-run``
    once with ``-m DDR`` and once with ``-m XYRAM``. Every file matching
    ``test_conv_1d_*.csv`` in the current working directory is then moved to the
    ``--artifacts-path`` directory.

    Args:
        request: pytest ``request`` fixture, used to read the ``--artifacts-path`` option.

    Raises:
        subprocess.CalledProcessError: the benchmark exited with a non-zero status.
        subprocess.TimeoutExpired: a run took longer than 300 seconds.
    """
    # Read the option first so a misconfigured run fails before the benchmarks start.
    artifacts_path = request.config.getoption("--artifacts-path")

    binary_path = os.path.join(os.path.dirname(__file__), "elf", "test_conv_1d.elf")
    for mem in ("DDR", "XYRAM"):
        subprocess.run(
            ["elcorecl-run", "-e", binary_path, "--", "-m", mem],
            check=True,
            timeout=300,
        )

    # Move .csv files to artifacts
    for csv_file in glob.glob("test_conv_1d_*.csv"):
        shutil.move(csv_file, f"{artifacts_path}/{csv_file}")


def test_signal_lib_conv_2d():
    """Run the 2D convolution test binary on generated data for several size sets.

    For every entry of ``matrices_sizes`` the generator script
    ``test_conv/gen_conv_test_data.py`` is executed with the four values passed as
    ``-r0``/``-c0``/``-r1``/``-c1`` and the data directory as ``-l``. Then
    ``elf/test_conv_2d.elf`` is launched through ``elcorecl-run`` for each memory
    location with the same size flags, ``input_data0.bin``/``input_data1.bin`` as inputs
    and ``output_data0.bin`` as output. The ``test_conv/test_data`` directory is removed
    afterwards, whether or not the runs succeeded.

    Raises:
        subprocess.CalledProcessError: the generator or the binary exited with a
            non-zero status.
        subprocess.TimeoutExpired: the generator exceeded 600 seconds or the binary
            exceeded 300 seconds.
    """
    data_locs = ["DDR", "XYRAM"]
    matrices_sizes = [
        ["16", "16", "16", "16"],
        ["16", "32", "32", "64"],
        ["256", "128", "32", "16"],
        ["256", "128", "128", "256"],
    ]
    base_dir = os.path.dirname(__file__)
    gen_script_path = os.path.join(base_dir, "test_conv", "gen_conv_test_data.py")
    test_data_path = os.path.join(base_dir, "test_conv", "test_data")
    binary_path = os.path.join(base_dir, "elf", "test_conv_2d.elf")

    os.makedirs(test_data_path, exist_ok=True)
    try:
        for r0, c0, r1, c1 in matrices_sizes:
            # The generator and the binary receive identical size flags.
            size_args = ["-r0", r0, "-c0", c0, "-r1", r1, "-c1", c1]
            subprocess.run(
                [".venv/bin/python", gen_script_path, *size_args, "-l", test_data_path],
                check=True,
                timeout=600,
            )
            for mem in data_locs:
                subprocess.run(
                    [
                        "elcorecl-run",
                        "-e",
                        binary_path,
                        "--",
                        "-i",
                        test_data_path + "/input_data0.bin",
                        test_data_path + "/input_data1.bin",
                        "-o",
                        test_data_path + "/output_data0.bin",
                        "-m",
                        mem,
                        *size_args,
                    ],
                    check=True,
                    timeout=300,
                )
    finally:
        # Remove the generated data even when a run fails, so nothing stale is left in
        # the source tree. A directory that is already gone is not an error.
        try:
            shutil.rmtree(test_data_path)
        except FileNotFoundError:
            pass


def _read_float_samples(path):
    """Return the whole file at *path* unpacked as a list of C ``float`` values.

    The native ``struct`` format ``"f"`` is used; ``struct.error`` is raised when the
    file size is not a whole number of 4-byte samples.
    """
    with open(path, "rb") as file:
        content = file.read()
    return list(struct.unpack(f"{len(content) // 4}f", content))


def _plot_fir_filter_results(test_data_path, win, filt, artifacts_path):
    """Plot signals and spectra for one window/filter pair and save the figure.

    Reads ``clear_signal.bin`` and ``signal_with_noise.bin`` from *test_data_path* and
    the ``test_fir_filter_<win>_<FILT>_XYRAM_{h,signal}.csv`` files from the current
    working directory, draws a 2x4 grid (top row: time domain, bottom row: spectrum or
    frequency response) and writes ``<filt>_FIR_filter_<win>.png`` to *artifacts_path*.
    """
    x = _read_float_samples(f"{test_data_path}/clear_signal.bin")
    xn = _read_float_samples(f"{test_data_path}/signal_with_noise.bin")

    csv_prefix = f"test_fir_filter_{win}_{filt.upper()}_XYRAM"
    h = pd.read_csv(f"{csv_prefix}_h.csv")["h_func"]
    y = pd.read_csv(f"{csv_prefix}_signal.csv")["signal"]

    # (time-domain title, frequency-domain title, samples, is filter response)
    panels = [
        ("Signal", "Signal spectrum", x, False),
        # Only every second sample (odd indices) of the noisy file is plotted;
        # presumably the file stores interleaved data - TODO confirm.
        ("Signal with noise", "Signal with noise spectrum", xn[1::2], False),
        ("FIR impulse response", "FIR frequency response", h, True),
        ("Filtered signal", "Filtered signal spectrum", y, False),
    ]

    fig = plt.figure(figsize=(16, 9))
    try:
        fig.suptitle("Lowpass FIR Filter Results")

        for i, (time_title, freq_title, sig, is_filter) in enumerate(panels):
            color = f"C{i}"

            top = fig.add_subplot(2, 4, i + 1)
            top.plot(sig, color=color)
            top.minorticks_on()
            top.grid(True, which="major")
            top.grid(which="minor", linestyle=":")
            top.set_title(time_title)
            top.set_xlim([0, len(sig)])

            bottom = fig.add_subplot(2, 4, i + 5)
            if is_filter:
                w, resp = signal.freqz(sig)
                resp = np.abs(resp) + 10 ** (-10)
                # freqz returns frequencies in [0, pi); dividing by pi gives the
                # normalised frequency that belongs to each response sample.
                bottom.plot(w / np.pi, 20 * np.log10(resp), color=color)
                bottom.set_xlim([0, 1])
                bottom.set_xticks(np.arange(0, 1.1, 0.1))
            else:
                # Normalise by the largest magnitude: np.max on the complex FFT output
                # orders values by real part, so the peak would not reliably be 0 dB.
                magnitude = np.abs(fft.fft(np.asarray(sig)))
                spectrum_db = 20 * np.log10(1e-10 + magnitude / np.max(magnitude))
                bottom.plot(spectrum_db, color=color)
                bottom.set_xlim([0, len(sig) // 2 + 1])
                bottom.set_xticks(np.arange(0, len(sig) // 2 + 1, 50))
            bottom.minorticks_on()
            bottom.grid(True, which="major")
            bottom.grid(which="minor", linestyle=":")
            bottom.set_title(freq_title)

        fig.tight_layout()
        fig.savefig(f"{artifacts_path}/{filt}_FIR_filter_{win}.png")
    finally:
        # Close this figure explicitly, also when plotting or saving fails.
        plt.close(fig)


def test_signal_lib_fir_filter(request):
    """Run the FIR filter test binary for every window, memory and filter type, then plot.

    For each window the generator script ``test_filters/gen_test_data.py`` is executed
    (with ``-p`` added only for ``chebwin``), after which ``elf/test_fir_filter.elf`` is
    launched through ``elcorecl-run`` for each memory location and filter type. Finally
    the results for the ``nuttall`` window and the lowpass filter are plotted and saved
    to the ``--artifacts-path`` directory. The ``test_filters/test_data`` directory is
    removed afterwards, whether or not the steps above succeeded.

    Args:
        request: pytest ``request`` fixture, used to read the ``--artifacts-path`` option.

    Raises:
        subprocess.CalledProcessError: the generator or the binary exited with a
            non-zero status.
        subprocess.TimeoutExpired: the generator exceeded 120 seconds or the binary
            exceeded 300 seconds.
    """
    data_locs = ["XYRAM", "DDR"]
    windows = ["hamming", "nuttall", "blackman", "blackmanharris", "hann", "lanczos", "chebwin"]
    filter_types = ["lowpass", "highpass", "bandpass", "bandstop"]
    filter_size = "173"
    w0 = "0.1"
    w1 = "0.9"
    signal_size = "512"
    param = "100"

    # Read the option first so a misconfigured run fails before the benchmarks start.
    artifacts_path = request.config.getoption("--artifacts-path")

    base_dir = os.path.dirname(__file__)
    gen_script_path = os.path.join(base_dir, "test_filters", "gen_test_data.py")
    test_data_path = os.path.join(base_dir, "test_filters", "test_data")
    binary_path = os.path.join(base_dir, "elf", "test_fir_filter.elf")

    os.makedirs(test_data_path, exist_ok=True)
    try:
        for win in windows:
            gen_command = [
                ".venv/bin/python",
                gen_script_path,
                "-w",
                win,
                "-t",
                filter_size,
                "-w0",
                w0,
                "-w1",
                w1,
                "-s",
                signal_size,
                "-l",
                test_data_path,
            ]
            if win == "chebwin":
                gen_command.extend(["-p", param])
            subprocess.run(
                gen_command,
                check=True,
                timeout=120,
            )
            for mem in data_locs:
                for filter_type in filter_types:
                    subprocess.run(
                        [
                            "elcorecl-run",
                            "-e",
                            binary_path,
                            "--",
                            "-i",
                            f"{test_data_path}/signal_with_noise.bin",
                            "-o",
                            f"{test_data_path}/h_{filter_type}.bin",
                            f"{test_data_path}/signal_after_{filter_type}.bin",
                            "-m",
                            mem,
                            "-w",
                            win,
                            "-p",
                            param,
                            "-t",
                            filter_size,
                            "-f",
                            filter_type.upper(),
                            "-w0",
                            w0,
                            "-w1",
                            w1,
                            "-s",
                            signal_size,
                        ],
                        check=True,
                        timeout=300,
                    )

        # NOTE(review): the .bin inputs plotted here are whatever the last generator
        # run (chebwin) left in test_data, while the CSVs belong to the nuttall window;
        # confirm the generated signal does not depend on the window.
        _plot_fir_filter_results(test_data_path, windows[1], "lowpass", artifacts_path)
    finally:
        # Remove the generated data even when a step fails, so nothing stale is left in
        # the source tree. A directory that is already gone is not an error.
        try:
            shutil.rmtree(test_data_path)
        except FileNotFoundError:
            pass
