#!/usr/bin/env python3

"""
This script is a single channel full loopback for the AIR-T that tests the AGC functionality. It
creates a signal in software and transmits that signal out of the AIR-T and receives it back into
the corresponding receiver channel. The power is slowly swept in increasing levels until the AGC
engages. The result is a plot of TX Gain value vs RX Received Power which should level out as the
AGC engages.
"""

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import SoapySDR
from SoapySDR import Device, SOAPY_SDR_TX, SOAPY_SDR_RX, SOAPY_SDR_CS16, SOAPY_SDR_CF32, errToStr, SOAPY_SDR_UNDERFLOW
import concurrent.futures
import threading
import time
import sys
import logging
from pathlib import Path


# Path of this script; logged at the start of a run.
current_file = Path(__file__)

# Default Settings
# Transmitter: sample rate (Hz), RF tone frequency (Hz), samples per buffer, TX amp bypass flag.
TX_SETTINGS = {'fs': 15.625e6, 'rf': 2405e6, 'buff_len': 2**16, 'tx_amp_bypass': False}
# Receiver: sample rate (Hz), LO frequency (Hz), samples per buffer, RX path name.
# Note: main() writes the "rx_path" setting itself and does not read 'rx_path' from here.
RX_SETTINGS = {'fs': 15.625e6, 'lo': 2400e6, 'buff_len': 2**16, 'rx_path': 'lna'}
POW_TOL_DB = 6  # Allowed |expected - measured| RX power error, in dB
MAX_CABLE_LOSS = 5  # Largest loopback cable-loss correction (dB) accepted before erroring


class SDRLogger(logging.Logger):
    """
    Logger for SDR performance tests that names logs by device serial and script,
    and routes INFO to stdout while warnings/errors go to stderr.

    Records go to ``<serial>_<script stem>.log`` (deleted and recreated on every run)
    and to the console. Used as a context manager, it logs any exception escaping
    the ``with`` body (without suppressing it) and then closes the file handler(s).
    """

    class ScreenLogHandler(logging.StreamHandler):
        """Split log output between stdout and stderr based on level."""
        def emit(self, record):
            # Pick the destination per record: WARNING and above -> stderr, else stdout.
            self.stream = sys.__stderr__ if record.levelno >= logging.WARNING else sys.__stdout__
            super().emit(record)
            if self.stream and not self.stream.closed:
                self.stream.flush()
            # Do not hold on to a stream reference between records.
            self.stream = None

    def __enter__(self):
        """Return self for context-manager usage."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Log an escaping exception, then close and detach resource-owning handlers.

        The exception, if any, is not suppressed (returns None).
        """
        if exc_type:
            # Pass the exception explicitly instead of relying on sys.exc_info().
            self.error(f"Script crashed: {exc_val}", exc_info=(exc_type, exc_val, exc_tb))

        self.info("Closing logger and flushing handlers.")
        # Iterate over a snapshot: removeHandler() mutates self.handlers, and removing
        # entries from the list being iterated silently skips every other handler.
        for handler in list(self.handlers):
            if isinstance(handler, self.ScreenLogHandler):
                # The console handler owns no stream (it borrows sys.__stdout__ /
                # sys.__stderr__ per record), so there is nothing to release. It stays
                # attached so that messages logged after the ``with`` block still
                # reach the screen.
                continue
            # Close file handlers (etc.) so no buffered data is lost.
            handler.close()
            self.removeHandler(handler)

    def __init__(self, script_path=__file__):
        """Create a serial-tagged logger with file and console handlers."""
        # 1. Get hardware info (placeholder serials if no device / enumeration fails)
        try:
            results = SoapySDR.Device.enumerate()
            serial = results[0]['serial'] if results else "unknown_sdr"
        except Exception:
            serial = "error_sdr"

        # 2. Initialize the base logger with the script name
        script_path = Path(script_path)
        logger_name = script_path.stem
        super().__init__(logger_name)
        self.setLevel(logging.INFO)

        # 3. Setup File Logging (Auto-delete existing)
        log_file = Path(f"{serial}_{logger_name}.log")
        log_file.unlink(missing_ok=True)

        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)

        # 4. Setup Screen Logging
        screen_handler = self.ScreenLogHandler()
        screen_handler.setLevel(logging.INFO)

        # 5. Add a shared formatter for readability
        formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
        file_handler.setFormatter(formatter)
        screen_handler.setFormatter(formatter)

        self.addHandler(file_handler)
        self.addHandler(screen_handler)

    def log_val(self, key, val, sep=35):
        """Log a key/value pair with the key left-aligned in a field of width sep."""
        self.info(f'{key:<{sep}}: {val}')


class PdfWriter:
    """
    Helper that writes Matplotlib figures to a per-run PDF report.

    The report is ``<serial>_<script stem>.pdf`` in the current directory; an existing
    file with that name is deleted first. As a context manager it closes the PDF on exit.
    """
    def __init__(self, script_path=__file__):
        """Create and open a PDF report file tied to the SDR serial."""
        # 1. Get hardware info (placeholder serials if no device / enumeration fails)
        try:
            results = SoapySDR.Device.enumerate()
            serial = results[0]['serial'] if results else "unknown_sdr"
        except Exception:
            serial = "error_sdr"
        script_path = Path(script_path)
        self._filename = Path(f"{serial}_{script_path.stem}.pdf")
        # Start each run with a fresh report.
        self._filename.unlink(missing_ok=True)
        self._pp = PdfPages(self._filename)

    def __enter__(self):
        """Return self for context-manager usage."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the PDF on exit and emit a note if an error occurred."""
        self.close()
        # If there was an error, this prints it but doesn't stop the PDF from closing
        if exc_type:
            print(f"PDF closed during cleanup after error: {exc_val}")

    def add_section(self, section_title=''):
        """
        Write a title-only page to the report.
        :param section_title: str
        :return:
        """
        fig_handle, ax = plt.subplots(1, 1, figsize=(8.5, 11))
        ax.axis('off')
        fig_handle.suptitle(f'{section_title}\n\n', fontsize=18)
        self._pp.savefig(fig_handle)
        # This figure is owned here; close it so pyplot does not keep it alive.
        plt.close(fig_handle)

    def add_page(self, fig_handle):
        """
        Add a Matplotlib figure as a PDF page. The caller keeps ownership of the figure.
        :param fig_handle: matplotlib fig handle
        :return:
        """
        self._pp.savefig(fig_handle)

    def close(self):
        """
        Close the PDF writer so the file can be opened.
        :return: str message naming the written report file
        """
        self._pp.close()
        return 'Report written to: {}'.format(self._filename)

    @property
    def filename(self):
        """Path of the PDF report file."""
        return self._filename


class TransmitterBackgroundTask:
    """
    Repeatedly write one set of TX buffers on a worker thread until stop() is called.
    """
    def __init__(self, _sdr, _tx_stream):
        """Bind SDR and TX stream and prepare a stop event."""
        self.sdr = _sdr
        self.tx_stream = _tx_stream
        self.tx_stop_event = threading.Event()
        self.tx_task = None
        # Executor that owns the worker thread; kept so stop() can shut it down.
        self._executor = None

    def _transmit_data_task(self, _buff_list, _samples_per_buffer):
        """
        Continuously write buffers until the stop event is set.

        :raises RuntimeError: if writeStream does not accept the full buffer. The loop
            ends and the error is delivered through the future that stop() waits on.
        """
        while not self.tx_stop_event.is_set():
            rc = self.sdr.writeStream(
                self.tx_stream, _buff_list, _samples_per_buffer)
            if rc.ret != _samples_per_buffer:
                # Surface the failure instead of silently ending the transmission.
                raise RuntimeError(
                    f'TX writeStream returned {rc.ret}, expected {_samples_per_buffer} samples')

    def repeat_buffer(self, buffer_list, samples_per_buffer, sleep_sec=0.5):
        """Activate the TX stream, start the background transmit loop, then wait sleep_sec."""
        self.tx_stop_event.clear()
        self.sdr.activateStream(self.tx_stream)
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self.tx_task = self._executor.submit(
            self._transmit_data_task, buffer_list, samples_per_buffer
        )
        # Wait briefly to ensure the stream is warmed up and transmitting
        time.sleep(sleep_sec)

    def stop(self):
        """Stop TX thread, flush hardware buffers, and deactivate stream."""
        # Signal the loop to stop
        self.tx_stop_event.set()
        # Wait for the Python thread to finish moving data; report any write failure.
        try:
            if self.tx_task:
                self.tx_task.result(timeout=2.0)
        except Exception as e:
            print(f"Thread exit error: {e}")
        finally:
            if self._executor is not None:
                # Non-blocking: if the worker is still inside writeStream after the
                # timeout above, waiting here could hang.
                self._executor.shutdown(wait=False)
                self._executor = None
        # Wait (up to 2 s) for the hardware to drain, signalled by an underflow status.
        start_flush = time.time()
        while (time.time() - start_flush) < 2.0:
            # Poll the stream status for the Underflow flag
            status = self.sdr.readStreamStatus(self.tx_stream, timeoutUs=100000)
            if status.ret == 0 and (status.flags & SOAPY_SDR_UNDERFLOW):
                break  # Buffer is empty!
        # Finally deactivate the stream
        self.sdr.deactivateStream(self.tx_stream)


def lookup_truth(x_val):
    """
    Return the expected RX power (dB) for a TX gain setting by linear interpolation.

    The calibration table spans TX gains -50 dB .. 0 dB in 0.25 dB steps; inputs
    outside that span are clamped to the end values by np.interp. x_val may be a
    scalar or an array.

    Valid only for:
      TX_SETTINGS = dict(fs=15.625e6, rf=2405e6, buff_len=2**16, tx_amp_bypass=False)
      RX_SETTINGS = dict(fs=15.625e6, lo=2400e6, buff_len=2**16, rx_path='lna')
    """
    # 201 gain points: -50.0, -49.75, ..., 0.0
    gain_grid = np.linspace(-50.0, 0.0, 201)
    # Truth received power (rounded to 2 decimals) below the AGC threshold ...
    rising_part = [
        -65.43, -65.20, -64.96, -64.74, -64.47, -64.21, -63.93, -63.74, -63.42, -63.21,
        -62.98, -62.70, -62.45, -62.24, -61.98, -61.74, -61.49, -61.27, -60.99, -60.73,
        -60.50, -60.27, -60.00, -59.75, -59.56, -59.27, -58.97, -58.72, -58.51, -58.30,
        -58.02, -57.73, -57.47, -57.27, -57.02, -56.76, -56.52, -56.21, -56.05, -55.78,
        -55.53, -55.27, -54.98, -54.77, -54.50, -54.26, -54.05, -53.78, -53.49, -53.25,
        -52.99, -52.78, -52.49, -52.27, -52.04, -51.74, -51.56, -51.30, -51.00, -50.81,
        -50.60, -50.25, -50.00, -49.78, -49.62, -49.33, -49.16, -48.80, -48.55, -48.29,
        -48.03, -47.77, -47.66, -47.29, -47.12, -46.85, -46.57, -46.34, -46.05, -45.81,
        -45.65, -45.42, -45.15, -44.92, -44.56, -44.32, -44.07, -43.80, -43.65, -43.31,
        -43.11, -42.90, -42.63, -42.38, -42.11, -41.90, -41.67, -41.49, -41.22, -40.96,
        -40.68, -40.42, -40.18, -39.90, -39.74, -39.46, -39.27, -38.99, -38.68, -38.50,
        -38.24, -37.91, -37.82, -37.51, -37.27, -37.05, -36.81, -36.53, -36.28, -36.05,
        -35.90, -35.60, -35.31, -35.15, -34.84, -34.62, -34.36, -34.14, -33.88, -33.69,
        -33.40, -33.08, -32.91, -32.60, -32.43, -32.14, -31.88, -31.67, -31.38, -31.10,
        -30.90, -30.65, -30.37, -30.09, -29.91, -29.60, -29.46, -29.10, -28.91, -28.71,
        -28.43, -28.13,
    ]
    # ... and the flat -28 dB plateau for the remaining 49 gain points.
    plateau_part = [-28.00] * 49
    power_grid = np.array(rising_part + plateau_part)
    return np.interp(x_val, gain_grid, power_grid)


def psd(sig: np.ndarray, fs: float, fcen: float = 0) -> tuple:
    """
    Compute the peak PSD and its frequency for a complex input signal.

    :param sig: complex samples, 1D ``(nfft,)`` or 2D ``(n_segments, nfft)``; the rows of
        a multi-row 2D input are averaged (noncoherent integration).
    :param fs: sample rate (Hz).
    :param fcen: center frequency added to the baseband frequency axis.
    :return: (peak frequency, peak power in dB) with this script's scaling
        ``|FFT / (2 * nfft)|**2 / (fs / 1e6)``.
    """
    nfft = sig.shape[-1]
    S0 = np.fft.fft(sig, nfft)
    S1 = np.fft.fftshift(S0, -1) / (2 * nfft)
    psd_seg = (S1 * S1.conjugate()).real / (fs / 1e6)  # FS / MHz
    if len(sig.shape) == 1:  # 1D array
        y = psd_seg
    else:  # 2D array
        if sig.shape[0] == 1:
            y = psd_seg.squeeze()  # remove unnecessary dimension
        else:
            y = np.mean(psd_seg, 0)  # Noncoherent integration to reduce dims
    # Bin frequencies matching the fftshift-ed spectrum: DC sits at index nfft // 2.
    # (The previous hand-built axis added an extra fs/nfft, shifting every bin by one.)
    x = fcen + np.fft.fftshift(np.fft.fftfreq(nfft, d=1.0 / fs))
    # Return the peak value and corresponding frequency
    return x[np.argmax(y)], 10 * np.log10(np.max(y))


def tone(n, fcen, samp_rate, a=None, phi=0.285):
    """
    Generate an interleaved int16 IQ tone with integer-period length for repeatability.

    :param n: number of complex samples; must hold an integer number of tone periods.
    :param fcen: tone frequency (same units as samp_rate).
    :param samp_rate: sample rate.
    :param a: amplitude. None -> 20000 on I and Q; a scalar or one-element sequence ->
        that value on both I and Q; a two-element sequence -> (I amplitude, Q amplitude).
    :param phi: starting phase in radians.
    :return: int16 array of length 2*n laid out I0, Q0, I1, Q1, ... (values truncated).
    :raises ValueError: if n is not an integer multiple of samp_rate / fcen.
    """
    period_len = samp_rate / fcen
    # Explicit check (not assert) so the validation survives `python -O`.
    if n % period_len != 0:
        raise ValueError('Total samples not integer number of periods. '
                         'Samples Per Period = {}'.format(period_len))
    if a is None:
        amp_i, amp_q = 20000, 20000
    elif np.ndim(a) == 0:
        amp_i = amp_q = a
    elif len(a) == 1:
        amp_i = amp_q = a[0]
    else:
        amp_i, amp_q = a[0], a[1]
    # Make Tone Signal
    wt = 2 * np.pi * fcen * np.arange(n) / samp_rate
    sig_cplx = np.exp(1j * (wt + phi))
    sig_int16 = np.empty(2 * n, dtype=np.int16)
    sig_int16[::2] = amp_i * sig_cplx.real
    sig_int16[1::2] = amp_q * sig_cplx.imag

    return sig_int16


def tx2rx(_sdr, _tx_stream, _rx_stream, _tx_buff, _rx_buff, _tx_buff_len, _rx_buff_len,
          _rx_chan, _buff_dump=10):
    """
    Transmit a buffer, flush RX reads, then return the last RX buffer and AGC gain.

    The TX buffer is repeated on a background thread while the RX stream is activated
    and read ``_buff_dump + 1`` times. Only the final read is kept, in ``_rx_buff``,
    which is filled in place. The RX gain is read back right after that final read.
    The RX stream is deactivated and the TX loop stopped even if a read fails.

    :param _buff_dump: number of RX buffers to read and discard first (negative -> 0).
    :return: (_rx_buff, rx_gain)
    :raises IOError: if any RX read returns other than _rx_buff_len samples.
    """
    # Setup background task for transmitter and start transmitting
    tx_process = TransmitterBackgroundTask(_sdr, _tx_stream)
    tx_process.repeat_buffer([_tx_buff], _tx_buff_len)
    try:
        # The RX stream must be active before readStream; it is deactivated below on
        # every call, so it has to be (re)activated here each time.
        _sdr.activateStream(_rx_stream)
        try:
            # Flush the RX buffers a few times. Final buffer is the one we will use
            for _ in range(max(_buff_dump, 0) + 1):
                rc = _sdr.readStream(_rx_stream, [_rx_buff], _rx_buff_len)
                if rc.ret != _rx_buff_len:
                    raise IOError(f'Rx Error {rc.ret}: {errToStr(rc.ret)}')
            rx_gain = _sdr.getGain(SOAPY_SDR_RX, _rx_chan)
        finally:
            # Deactivate the RX stream
            _sdr.deactivateStream(_rx_stream)
    finally:
        # Stop the background transmit process (also deactivates the TX stream)
        tx_process.stop()

    return _rx_buff, rx_gain


def _check_peak_freq(meas_freq, tx_pars, rx_pars, tol_hz=10e3):
    """
    Raise ValueError if the measured peak is more than tol_hz away from tx_pars['rf'].

    The error message points the operator at the TX/RX loopback cabling.
    """
    if np.abs(tx_pars['rf'] - meas_freq) > tol_hz:
        raise ValueError(
            f"Peak at {meas_freq:,.0f} \u2260 {tx_pars['rf']:,.0f}. "
            f"Check Loopback Cable on TX{tx_pars['chan']} / RX{rx_pars['chan']}"
        )


def main(tx_pars, rx_pars, tx_gain_vals, report, logger, test_num=''):
    """
    Run the AGC validation sweep for one TX/RX channel pair and return pass/fail.

    The first value of tx_gain_vals is used to measure the loopback cable loss (the AGC
    must not be active there). Every gain value, shifted by that rounded loss, is then
    transmitted, and the received peak power is compared against lookup_truth().

    :param tx_pars: TX settings dict with 'fs', 'rf', 'buff_len' and 'chan'.
    :param rx_pars: RX settings dict with 'fs', 'lo', 'buff_len' and 'chan'.
    :param tx_gain_vals: sequence of TX gain values (dB) to sweep.
    :param report: object with add_page(fig) and filename (e.g. PdfWriter), or a falsy
        value to skip the plot.
    :param logger: SDRLogger-like logger providing info() and log_val().
    :param test_num: label used in the plot title.
    :return: True if at least one point was measured and every measured power was
        within POW_TOL_DB of its expected value, else False.
    :raises RuntimeError: if the AGC is already active at the calibration gain.
    :raises ValueError: if the tone peak is off frequency or the cable loss is too large.
    :raises IOError: if an RX read fails (raised by tx2rx).
    """
    # Baseband tone at fs/8; the TX LO sits bb_freq below tx_pars['rf'] so the tone lands on it.
    bb_freq = tx_pars['fs'] / 8  # baseband frequency of tone
    lo_freq = tx_pars['rf'] - bb_freq  # LO freq to put tone at desired rf
    # Create the tone to transmit before any stream is set up.
    tx_buff = tone(tx_pars['buff_len'], bb_freq, tx_pars['fs'])

    # Create AIR-T instance
    sdr = Device(dict(driver="SoapyAIRT"))
    # NOTE(review): rx_path is forced to "bypass" here, while RX_SETTINGS carries
    # rx_path='lna' and lookup_truth() documents its table as valid for rx_path='lna'.
    # Confirm which path the truth table was measured with before changing this.
    sdr.writeSetting(SOAPY_SDR_RX, rx_pars['chan'], "rx_path", "bypass")
    sdr.writeSetting(SOAPY_SDR_TX, tx_pars['chan'], "tx_amp_bypass", "false")

    # Setup Transmitter
    sdr.setSampleRate(SOAPY_SDR_TX, tx_pars['chan'], tx_pars['fs'])
    sdr.setFrequency(SOAPY_SDR_TX, tx_pars['chan'], lo_freq)
    tx_stream = sdr.setupStream(SOAPY_SDR_TX, SOAPY_SDR_CS16, [tx_pars['chan']])

    # Setup Receiver with automatic gain control enabled
    sdr.setSampleRate(SOAPY_SDR_RX, rx_pars['chan'], rx_pars['fs'])
    sdr.setFrequency(SOAPY_SDR_RX, rx_pars['chan'], rx_pars['lo'])
    sdr.setGainMode(SOAPY_SDR_RX, rx_pars['chan'], True)
    rx_stream = sdr.setupStream(SOAPY_SDR_RX, SOAPY_SDR_CF32, [rx_pars['chan']])
    rx_buff = np.zeros(rx_pars['buff_len'], np.complex64)

    # Per-point measurements of the sweep
    tx_gain = []
    agc_gain = []
    meas_pow = []
    expected_pow = []
    meas_pow_err = []
    try:
        # Measure cable loss for the current loopback at the first gain value. The TX gain
        # of every sweep point is then raised by that loss, so different cables give
        # comparable received power.
        cal_tx_gain_val = tx_gain_vals[0]
        sdr.setGain(SOAPY_SDR_TX, tx_pars['chan'], cal_tx_gain_val)
        sig, current_agc_gain = tx2rx(sdr, tx_stream, rx_stream, tx_buff, rx_buff,
                                      tx_pars['buff_len'], rx_pars['buff_len'], rx_pars['chan'])
        curr_meas_freq, current_meas_pow = psd(sig, rx_pars['fs'], rx_pars['lo'])
        # Ensure the AGC is not active at the calibration point and the peak is in place.
        # Explicit raises (not assert) so the checks survive `python -O`.
        if current_agc_gain != 0:
            raise RuntimeError("Calibration failed, AGC was enabled for current TX Gain value")
        _check_peak_freq(curr_meas_freq, tx_pars, rx_pars)

        cable_loss_val = lookup_truth(cal_tx_gain_val) - current_meas_pow
        gain_range = sdr.getGainRange(SOAPY_SDR_TX, tx_pars['chan'])
        gain_min = gain_range.minimum()
        gain_max = gain_range.maximum()
        gain_step = gain_range.step()
        if gain_step > 0:
            # Quantize to the TX gain step so adjusted gains are settable values.
            tx_gain_adjustment = np.round(cable_loss_val / gain_step) * gain_step
        else:
            # No step reported: dividing by it would be a division by zero.
            tx_gain_adjustment = cable_loss_val
        # Written as `not (<)` so a NaN adjustment is rejected as well.
        if not tx_gain_adjustment < MAX_CABLE_LOSS:
            raise ValueError(
                f"Cable loss ({cable_loss_val:.1f} dB) too large to activate AGC. "
                f"Please use a shorter cable with less than {MAX_CABLE_LOSS:.1f} dB of loss.")

        # Log the settings
        logger.log_val("RX AGC Validation Channel", rx_pars['chan'])
        logger.log_val("Rx Path Setting", sdr.readSetting(SOAPY_SDR_RX, rx_pars['chan'], 'rx_path'))
        logger.log_val("RX Sample Rate", sdr.getSampleRate(SOAPY_SDR_RX, rx_pars['chan']))
        logger.log_val("RX Tune Frequency", sdr.getFrequency(SOAPY_SDR_RX, rx_pars['chan']))
        logger.log_val("RX Buffer Len", rx_pars['buff_len'])
        logger.log_val("TX Transmit Channel", tx_pars['chan'])
        logger.log_val("Tx Amp Bypass", sdr.readSetting(SOAPY_SDR_TX, tx_pars['chan'], 'tx_amp_bypass'))
        logger.log_val("TX Sample Rate", sdr.getSampleRate(SOAPY_SDR_TX, tx_pars['chan']))
        logger.log_val("TX Tune Frequency", sdr.getFrequency(SOAPY_SDR_TX, tx_pars['chan']))
        logger.log_val("TX Buffer Len", tx_pars['buff_len'])
        logger.log_val('Loopback cable calibration value', f'{tx_gain_adjustment} dB')
        logger.log_val("Starting Transmit Power Sweep", f"TX{tx_pars['chan']} -> RX{rx_pars['chan']}")
        logger.info('| TX Gain (dB) | AGC Gain (dB) | Meas Power (dB) | Expected Power (dB) | Result |')
        logger.info('|--------------|---------------|-----------------|---------------------|--------|')
        format_str = '| {: >12.2f} | {: >13.2f} | {: >15.1f} | {: >19.1f} | {: >6s} |'

        # Main power sweep
        for tx_gain_val in tx_gain_vals:
            # Adjust the gain value to take the cable loss into account
            adjusted_tx_gain_val = tx_gain_val + tx_gain_adjustment
            expected_rx_val = lookup_truth(adjusted_tx_gain_val)
            # Don't do the test if the necessary TX gain is outside the available range.
            # No gain is set for this point, so report the requested (adjusted) value.
            if not (gain_min <= adjusted_tx_gain_val <= gain_max):
                logger.info(format_str.format(adjusted_tx_gain_val, float('nan'), float('nan'),
                                              float('nan'), 'VOID'))
                continue
            # Set the current TX gain value and read it back
            sdr.setGain(SOAPY_SDR_TX, tx_pars['chan'], adjusted_tx_gain_val)
            current_tx_gain = sdr.getGain(SOAPY_SDR_TX, tx_pars['chan'])
            # Run the TX->RX loopback test (Transmits tone and receives it back)
            sig, current_agc_gain = tx2rx(sdr, tx_stream, rx_stream, tx_buff, rx_buff,
                                          tx_pars['buff_len'], rx_pars['buff_len'], rx_pars['chan'])
            # Measure the power received from the loopback and ensure that peak is at the right freq
            curr_meas_freq, current_meas_pow = psd(sig, rx_pars['fs'], rx_pars['lo'])
            _check_peak_freq(curr_meas_freq, tx_pars, rx_pars)
            curr_meas_pow_err = np.abs(expected_rx_val - current_meas_pow)
            # Same criterion as the overall verdict below (error <= POW_TOL_DB).
            status = 'PASS' if curr_meas_pow_err <= POW_TOL_DB else '<--- FAIL'
            logger.info(format_str.format(current_tx_gain, current_agc_gain, current_meas_pow,
                                          expected_rx_val, status))
            # Save measurements for this loop iteration
            tx_gain.append(current_tx_gain)
            agc_gain.append(current_agc_gain)
            meas_pow.append(current_meas_pow)
            expected_pow.append(expected_rx_val)
            meas_pow_err.append(curr_meas_pow_err)
    finally:
        # Close the TX and RX streams, also when a check above raised
        sdr.closeStream(tx_stream)
        sdr.closeStream(rx_stream)

    # A sweep in which every point was VOID must not pass vacuously.
    result = bool(meas_pow_err) and all(err <= POW_TOL_DB for err in meas_pow_err)
    result_str = 'TEST PASS' if result else 'TEST FAIL'

    logger.log_val('Result', result_str)
    # Create plot and add to report
    if report:
        fig, ax1 = plt.subplots(figsize=(11, 8))
        ax1.plot(tx_gain, meas_pow, 'g-', label='Measured Power')
        ax1.plot(tx_gain, np.array(expected_pow) + POW_TOL_DB, '--g', label='Measured Power Upper')
        ax1.plot(tx_gain, np.array(expected_pow) - POW_TOL_DB, '--g', label='Measured Power Lower')
        ax1.set_xlabel('TX Gain (dB)')
        ax1.set_ylabel('Measured Power Calibrated (dB)', color='g')
        ax2 = ax1.twinx()
        ax2.plot(tx_gain, agc_gain, 'b-', label='AGC Gain')
        ax2.set_ylabel('AGC Gain (dB)', color='b')
        plt.title(f"Test {test_num}, TX Chan {tx_pars['chan']}, RX Chan {rx_pars['chan']}, {result_str}")
        report.add_page(fig)
        logger.log_val('Figure added to pdf', report.filename)
        # Close only the figure created here, never an unrelated "current" figure.
        plt.close(fig)
    return result


if __name__ == '__main__':
    # Default Test Settings
    n_tests = 5
    chan_vals = [2, 3]
    # Define test points and lookup corresponding truth values.
    # Note that the TX gain values must be in the range of [-50, 0].
    tx_gain_values = np.arange(-30, 0, 1, dtype=float)
    # Work on copies so the per-channel 'chan' key set below does not leak into the
    # module-level default dicts.
    tx_settings = dict(TX_SETTINGS)
    rx_settings = dict(RX_SETTINGS)

    start_time = time.time()
    results = []
    with SDRLogger() as logger, PdfWriter() as report:
        logger.log_val('Test Script', current_file)
        logger.info('Hardware Information:')
        for dev_par, dev_val in SoapySDR.Device.enumerate()[0].items():
            logger.log_val(dev_par, dev_val)
        for test_num in range(n_tests):
            logger.info('---------------------------------------------------------------------')
            logger.log_val('Test Number', test_num + 1)
            for chan in chan_vals:
                tx_settings['chan'] = rx_settings['chan'] = chan
                # Use the same 1-based number as the log line above for the plot title.
                result = main(tx_settings, rx_settings, tx_gain_values, report, logger,
                              test_num=test_num + 1)
                results.append(result)
        # Summarize before the logger exits, so these lines also reach the log file
        # (its file handler is closed when the `with` block ends).
        minutes, seconds = divmod(time.time() - start_time, 60)
        logger.log_val('Elapsed time', f'{int(minutes)}m {seconds:.1f}s')
        if all(results):
            logger.log_val('Overall All Result', 'PASSED')
        else:
            logger.log_val('Overall All Result', 'FAILED')
