# -*- coding: utf-8 -*-
""" Run a test of the result unit together with thresholding and statistics.
"""

# Copyright 2018 Zurich Instruments AG

from __future__ import print_function
import time
import textwrap
import numpy as np

import zhinst.utils

from .common import initialize_device, acquisition_poll
from .common import ResultLoggingSource


def run_example(device_id, threshold=500, result_length=1500, num_averages=1, do_plot=True):
    """ Run a test of the result unit together with thresholding and statistics.

    The example applies a simple square wave to the instrument using the AWG.
    The integration functions use the full length of the integrators, and each
    integration function is basically just a constant value through the entire
    integration window, with different values for the different channels. We
    then sweep the starting point of the integration in relation to the pulse
    generated by the AWG. Initially, the integrators will not see the pulse at
    all, so the result will be zero. Then, as we gradually get more and more
    overlap of the integration function and the pulse, we will see a ramp up
    until a point in time where the integration window is completely within the
    pulse. Then, for larger delays we have the reverse process. We configure a
    fixed threshold for all channels and then we show how the threshold output
    toggles as the integration result goes above the threshold. We also read out
    the results from the statistics unit and show that in a table.

    Requirements:

      - Connect signal output 1 to signal input 1.
      - Connect signal output 2 to signal input 2.

    Arguments:

      device_id (str): The ID of the device to run the example with. For
        example, `dev2006` or `uhf-dev2006`.

      threshold (float): Quantization threshold, applied to all readout channels.

      result_length (int): Number of measurements.

      num_averages (int): Number of averages per measurement.

      do_plot (bool, optional): Specify whether to plot the polled data.

    Returns:

      result_data (dict): Transformation and thresholding results. The dict is
        keyed by the ResultLoggingSource used for the acquisition (TRANS and
        THRES). Each value is what `acquisition_poll` returned for the
        subscribed result wave paths. The plotting code iterates it as
        (path, samples) pairs.

      statistics (list of dicts): Quantization statistics. There is one dict
        with the keys 'ones' and 'flips' per readout channel (0-7, in order),
        read out after the TRANS acquisition.

    Raises:

      RuntimeError: If the AWG sequencer program fails to compile.

    """
    apilevel_example = 6  # The API level supported by this example.
    # Call a zhinst utility function that returns:
    # - an API session `daq` in order to communicate with devices via the data server.
    # - the device ID string that specifies the device branch in the server's node hierarchy.
    # - the device's discovery properties.
    required_devtype = 'UHFQA'
    required_options = None
    daq, device, _ = zhinst.utils.create_api_session(device_id, apilevel_example,
                                                     required_devtype=required_devtype,
                                                     required_options=required_options)

    # Perform initialization for UHFQA examples
    initialize_device(daq, device)

    # Configure AWG
    awg_program = textwrap.dedent("""\
    const RATE = 0;
    const FS = 1.8e9*pow(2, -RATE);
    const F_RES = 1.6e6;
    const LENGTH = 3.0e-6;
    const N = floor(LENGTH*FS);

    wave w = join(zeros(N), ones(N), zeros(N));

    setTrigger(AWG_INTEGRATION_ARM);
    var loop_cnt = getUserReg(0);
    var avg_cnt = getUserReg(1);
    var wait_delta = getUserReg(2);

    repeat (avg_cnt) {
        var wait_time = 0;

        repeat(loop_cnt) {
            wait_time = wait_time + wait_delta;
            setTrigger(AWG_INTEGRATION_ARM);
            playWave(w, w, RATE);
            wait(wait_time);
            setTrigger(AWG_INTEGRATION_TRIGGER + AWG_INTEGRATION_ARM);
            setTrigger(AWG_INTEGRATION_ARM);
            waitWave();
            wait(1000);
        }
    }

    setTrigger(0);
    """)

    # Create an instance of the AWG module
    awgModule = daq.awgModule()
    awgModule.set('device', device)
    awgModule.set('index', 0)
    awgModule.execute()

    # Transfer the AWG sequence program. Compilation starts automatically.
    awgModule.set('compiler/sourcestring', awg_program)
    # compiler/status stays at -1 while compilation is still running.
    while awgModule.getInt('compiler/status') == -1:
        time.sleep(0.1)

    # Ensure that compilation was successful (status 1 signals a failure).
    # Raise explicitly rather than `assert`, which is removed under `python -O`.
    if awgModule.getInt('compiler/status') == 1:
        raise RuntimeError('AWG program compilation failed: {}'.format(
            awgModule.getString('compiler/statusstring')))

    # Configure AWG program from registers:
    # reg 0 -> loop_cnt (measurements), reg 1 -> avg_cnt, reg 2 -> wait_delta.
    daq.setDouble('/{:s}/awgs/0/userregs/0'.format(device), result_length)
    daq.setDouble('/{:s}/awgs/0/userregs/1'.format(device), num_averages)
    daq.setDouble('/{:s}/awgs/0/userregs/2'.format(device), 1)

    # Configuration of weighted integration: a constant weight per channel,
    # decreasing from 1 towards 0.1 (zip uses only the first len(channels) weights).
    channels = [0, 1, 2, 3, 4, 5, 6, 7]
    weights = np.linspace(1, 0.1, 10)
    integration_length = 4096
    for ch, weight in zip(channels, weights):
        w = weight * np.ones(integration_length)
        daq.setVector('/{:s}/qas/0/integration/weights/{}/real'.format(device, ch), w)
        daq.setVector('/{:s}/qas/0/integration/weights/{}/imag'.format(device, ch), w)

    daq.setInt('/{:s}/qas/0/integration/length'.format(device), integration_length)
    daq.setInt('/{:s}/qas/0/integration/mode'.format(device), 0)
    daq.setInt('/{:s}/qas/0/delay'.format(device), 0)

    # Enable statistics
    daq.setInt('/{:s}/qas/0/result/statistics/length'.format(device), result_length)
    daq.setInt('/{:s}/qas/0/result/statistics/reset'.format(device), 1)
    daq.setInt('/{:s}/qas/0/result/statistics/enable'.format(device), 1)

    # Configure thresholds
    for ch in channels:
        daq.setDouble('/{:s}/qas/0/thresholds/{:d}/level'.format(device, ch), threshold)

    # Configure the result unit
    daq.setInt('/{:s}/qas/0/result/length'.format(device), result_length)
    daq.setInt('/{:s}/qas/0/result/averages'.format(device), num_averages)

    # Subscribe to result waves
    paths = ['/{:s}/qas/0/result/data/{:d}/wave'.format(device, ch) for ch in channels]
    daq.subscribe(paths)

    result_data = {}
    statistics = []
    try:
        # Acquire once with the transformation output as source, then once with
        # the thresholding output.
        for result_source in (ResultLoggingSource.TRANS, ResultLoggingSource.THRES):
            daq.setInt('/{:s}/qas/0/result/source'.format(device), result_source)

            # Now we're ready for readout. Enable result unit and start acquisition.
            daq.setInt('/{:s}/qas/0/result/reset'.format(device), 1)
            daq.setInt('/{:s}/qas/0/result/enable'.format(device), 1)
            daq.sync()

            # Arm the device
            daq.asyncSetInt('/{:s}/awgs/0/single'.format(device), 1)
            daq.syncSetInt('/{:s}/awgs/0/enable'.format(device), 1)

            # Perform acquisition
            print('Acquiring data for {!r}...'.format(result_source))
            result_data[result_source] = acquisition_poll(daq, paths, result_length)
            print('Done.')

            # Stop result unit
            daq.setInt('/{:s}/qas/0/result/enable'.format(device), 0)

            # Obtain statistics (read only once, after the first acquisition)
            if result_source == ResultLoggingSource.TRANS:
                for ch in channels:
                    num_ones = daq.getInt('/{:s}/qas/0/result/statistics/data/{:d}/ones'.format(device, ch))
                    num_flips = daq.getInt('/{:s}/qas/0/result/statistics/data/{:d}/flips'.format(device, ch))
                    statistics.append({'ones': num_ones, 'flips': num_flips})
    finally:
        # Unsubscribe even if the acquisition raised, so the subscription is not left behind.
        daq.unsubscribe(paths)

    print('Statistics:')
    print('{:15s}  {:>10s} {:>10s}'.format('Readout channel', '# ones', '# flips'))
    for ch, stats in zip(channels, statistics):
        print('{:15d}  {:10d} {:10d}'.format(ch, stats['ones'], stats['flips']))

    if do_plot:
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(ncols=2, figsize=(12, 4), sharex=True)
        ax[0].set_title('Transformation unit')
        ax[0].set_ylabel('Amplitude (a.u.)')
        ax[0].set_xlabel('Measurement (#)')
        # Dashed line marks the configured quantization threshold.
        ax[0].axhline(threshold, color='k', linestyle='--')
        for path, samples in result_data[ResultLoggingSource.TRANS].items():
            ax[0].plot(samples, label='{}'.format(path))
        ax[1].set_title('Thresholding unit')
        ax[1].set_ylabel('Amplitude (a.u.)')
        ax[1].set_xlabel('Measurement (#)')
        for path, samples in result_data[ResultLoggingSource.THRES].items():
            ax[1].plot(samples, label='{}'.format(path))
        fig.set_tight_layout(True)
        plt.show()

    return result_data, statistics
