# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd


def mne_channel_add(
    raw, channel, channel_type=None, channel_name=None, sync_index_raw=0, sync_index_channel=0
):
    """Add channel as array to MNE.

    Add a channel to a mne's Raw m/eeg file. The channel is shifted so that a common event is
    aligned in both signals, padded with NaNs or cropped to the length of the raw data, and then
    appended to a copy of ``raw``.

    Parameters
    ----------
    raw : mne.io.Raw
        Raw EEG data from MNE. It is not modified; a copy is returned.
    channel : list, array or pd.Series
        The signal to be added.
    channel_type : str
        Channel type. Currently supported fields are 'ecg', 'bio', 'stim', 'eog', 'misc', 'seeg',
        'ecog', 'mag', 'eeg', 'ref_meg', 'grad', 'emg', 'hbr' or 'hbo'. If None, 'misc' is used.
    channel_name : str
        Desired channel name. If None, the name of ``channel`` is used when it is a named
        ``pd.Series``; otherwise it is ``"Added_Channel"``. The name is converted to ``str``.
    sync_index_raw : int
        An index (e.g., the onset of the same event marked in both signals), in the raw data, by
        which to align the two inputs. This can be used in case the EEG data and the channel to add
        do not have the same onsets and must be aligned through some common event.
    sync_index_channel : int
        An index (e.g., the onset of the same event marked in both signals), in the channel to add,
        by which to align the two inputs. This can be used in case the EEG data and the channel to
        add do not have the same onsets and must be aligned through some common event.

    Returns
    -------
    mne.io.Raw
        Copy of ``raw`` with the added channel. Samples that the (shifted) channel does not cover
        are filled with NaN.

    Raises
    ------
    ImportError
        If the 'mne' package is not installed.

    Examples
    --------
    >>> import neurokit2 as nk
    >>> import mne
    >>>
    >>> raw = nk.mne_data("filt-0-40_raw")
    >>> ecg = nk.ecg_simulate(length=50000)
    >>>
    >>> # Let the 42nd sample point in the EEG signal correspond to the 333rd point in the ECG
    >>> event_index_in_eeg = 42
    >>> event_index_in_ecg = 333
    >>>
    >>> raw = nk.mne_channel_add(raw,
    ...                          ecg,
    ...                          sync_index_raw=event_index_in_eeg,
    ...                          sync_index_channel=event_index_in_ecg,
    ...                          channel_type="ecg")  # doctest: +SKIP

    """
    # Try loading mne
    try:
        import mne
    except ImportError as e:
        raise ImportError(
            "NeuroKit error: mne_channel_add(): the 'mne' module is required for this function "
            "to run. Please install it first (`pip install mne`)."
        ) from e

    if channel_name is None:
        series_name = channel.name if isinstance(channel, pd.Series) else None
        channel_name = "Added_Channel" if series_name is None else series_name
    # MNE requires string channel names; a Series may carry e.g. an integer column label.
    channel_name = str(channel_name)

    if channel_type is None:
        # Same default as mne.create_info's own `ch_types` default.
        channel_type = "misc"

    # Shift the channel so that `sync_index_channel` lands on `sync_index_raw`, then pad with
    # NaNs or crop so that it has exactly as many samples as the raw data.
    data = _mne_channel_sync(channel, len(raw), sync_index_channel - sync_index_raw)

    old_verbosity_level = mne.set_log_level(verbose="WARNING", return_old_level=True)
    try:
        info = mne.create_info([channel_name], raw.info["sfreq"], ch_types=channel_type)
        new_channel = mne.io.RawArray(data[np.newaxis, :], info)

        # Work on a copy so that the caller's object is left untouched
        raw = raw.copy()
        raw.add_channels([new_channel], force_update_info=True)  # In-place on the copy
    finally:
        # Always restore the caller's global MNE log level, even if MNE raised.
        mne.set_log_level(old_verbosity_level)

    return raw


def _mne_channel_sync(channel, n_times, shift):
    """Return ``channel`` shifted left by ``shift`` samples and fitted to ``n_times`` samples.

    Output sample ``i`` is ``channel[i + shift]`` when that index exists, otherwise NaN. A positive
    ``shift`` drops the first ``shift`` samples of the channel; a negative ``shift`` prepends
    ``-shift`` NaNs. The result is then NaN-padded or cropped to length ``n_times``.
    """
    signal = np.asarray(channel, dtype=float).ravel()
    synced = np.full(n_times, np.nan)

    # Range of output indices [start, stop) for which a source sample exists.
    start = max(0, -shift)
    stop = min(n_times, len(signal) - shift)
    if stop > start:
        synced[start:stop] = signal[start + shift : stop + shift]
    return synced
