
import logging
import contextvars
from contextlib import contextmanager

# numpy for array conversion/validation; os to locate the external data file
import numpy as np
import os

logger = logging.getLogger(__name__)

__all__ = ["eeg_checkset", "strict_mode"]

# Module-wide strict-mode flag, True unless overridden via strict_mode(False).
# Stored in a ContextVar so an override stays local to the current context.
_strict_mode_var = contextvars.ContextVar("strict_mode", default=True)

class DummyException(Exception):
    """Sentinel exception type that is never raised anywhere.

    Used as the ``except`` target in strict mode: a handler for an exception
    nobody raises matches nothing, so real errors propagate unchanged.
    """

@contextmanager
def strict_mode(enabled: bool):
    """
    Temporarily override strict mode for eeg_checkset inside a ``with`` block.

    Args:
        enabled (bool): True lets exceptions propagate (strict mode);
                        False makes eeg_checkset catch and log them instead.

    The previous setting is restored when the block exits, even if the body
    raises. The flag is held in a ``contextvars.ContextVar``, so the override
    only affects the current context.

    Usage:
        with strict_mode(False):
            EEG = eeg_checkset(EEG)  # Will catch and handle exceptions
    """
    previous = _strict_mode_var.set(enabled)
    try:
        yield
    finally:
        # Undo exactly our own set() call, restoring the prior value.
        _strict_mode_var.reset(previous)


# Expected type of every standard EEG field; a tuple means any listed type is
# accepted. Used by _check_field_types to add missing fields and warn on others.
_EXPECTED_TYPES = {
    'setname': str,
    'filename': str,
    'filepath': str,
    'subject': str,
    'group': str,
    'condition': str,
    'session': (str, int),
    'comments': np.ndarray,
    'nbchan': int,
    'trials': int,
    'pnts': int,
    'srate': (float, int),
    'xmin': float,
    'xmax': float,
    'times': np.ndarray,       # Expecting a float numpy array
    'data': np.ndarray,        # Expecting a float numpy array
    'icaact': np.ndarray,      # Expecting a float numpy array
    'icawinv': np.ndarray,     # Expecting a float numpy array
    'icasphere': np.ndarray,   # Expecting a float numpy array
    'icaweights': np.ndarray,  # Expecting a float numpy array
    'icachansind': np.ndarray, # Expecting an integer numpy array
    'chanlocs': np.ndarray,    # Expecting numpy array of dictionaries
    'urchanlocs': np.ndarray,  # Expecting numpy array of dictionaries
    'chaninfo': dict,
    'ref': str,
    'event': np.ndarray,       # Expecting numpy array of dictionaries
    'urevent': np.ndarray,     # Expecting numpy array of dictionaries
    'eventdescription': np.ndarray,  # Expecting numpy array of strings (not checked)
    'epoch': np.ndarray,       # Expecting numpy array of dictionaries
    'epochdescription': np.ndarray,  # Expecting numpy array of strings (not checked)
    'reject': dict,
    'stats': dict,
    'specdata': dict,
    'specicaact': dict,
    'splinefile': str,
    'icasplinefile': str,
    'dipfit': dict,
    'history': str,
    'saved': str,
    'etc': dict,
    'datfile': str,
    'run': (str, int),
    'roi': dict,
}

# Array fields whose element dtype / contents receive an additional check.
_FLOAT_ARRAY_FIELDS = frozenset(['times', 'data', 'icaact', 'icawinv', 'icasphere', 'icaweights'])
_INT_ARRAY_FIELDS = frozenset(['icachansind'])
_DICT_ARRAY_FIELDS = frozenset(['chanlocs', 'urchanlocs', 'event', 'urevent', 'epoch'])


def _as_object_array(value):
    """Wrap a lone dict in a list and turn lists into numpy object arrays.

    Any other value (e.g. an existing ``np.ndarray`` or ``None``) is returned
    unchanged.
    """
    if isinstance(value, dict):
        value = [value]
    if isinstance(value, list):
        value = np.asarray(value, dtype=object)
    return value


def _load_data_file(EEG):
    """Read the float32 data file named by ``EEG['data']``.

    Looks for ``EEG['data']`` inside ``EEG['filepath']``; if absent, falls
    back to ``EEG['filename']`` with its extension swapped for ``.fdt``.

    Returns:
        np.ndarray: float32 array of shape ``(nbchan, pnts, trials)``.

    Raises:
        FileNotFoundError: if neither candidate file exists.
    """
    # os.path.join (not "+ os.sep +") keeps an empty filepath relative
    # instead of producing a path at the filesystem root.
    file_path = EEG.get('filepath', '')
    file_name = os.path.join(file_path, EEG['data'])
    if not os.path.exists(file_name):
        # try to use the same name as the .set file but with .fdt extension;
        # splitext only touches the real extension (str.replace did not)
        stem = os.path.splitext(EEG['filename'])[0]
        file_name = os.path.join(file_path, stem + '.fdt')
        if not os.path.exists(file_name):
            raise FileNotFoundError(f"Data file {file_name} not found")

    nbchan, pnts, trials = EEG['nbchan'], EEG['pnts'], EEG['trials']
    # The reshape treats the file as one row per sample with the channel
    # index varying fastest; samples of each trial are contiguous.
    samples = np.fromfile(file_name, dtype='float32').reshape(pnts * trials, nbchan)
    return samples.T.reshape(nbchan, trials, pnts).transpose(0, 2, 1)


def _has_ica_matrices(EEG):
    """Return True when 'icaweights' and 'icasphere' both have a non-zero size."""
    return (getattr(EEG.get('icaweights'), 'size', 0) > 0
            and getattr(EEG.get('icasphere'), 'size', 0) > 0)


def _compute_icaact(EEG, exception_type):
    """Return ICA activations ``(icaweights @ icasphere) @ data`` as float32.

    The result is shaped ``(n_components, pnts, trials)``. Exceptions of
    ``exception_type`` are logged and an empty array is returned; any other
    exception propagates.
    """
    try:
        unmixing = np.dot(EEG['icaweights'], EEG['icasphere'])
        # Flatten all samples (and trials) into columns, one row per channel.
        icaact = np.dot(unmixing, EEG['data'].reshape(EEG['nbchan'], -1))
        return icaact.astype(np.float32).reshape(EEG['icaweights'].shape[0], -1, EEG['trials'])
    except exception_type as e:
        logger.error("Error computing ICA activations: %s", e)
        return np.array([])


def _default_for(expected_type):
    """Return the empty placeholder stored for a missing field of this type."""
    if expected_type is str:
        return ''
    if expected_type is int:
        return np.array([], dtype=int)
    if expected_type is float:
        return np.array([], dtype=float)
    if expected_type is dict:
        return {}
    # np.ndarray and tuple-typed fields ('session', 'run', ...)
    return np.array([])


def _check_field_types(EEG):
    """Add placeholders for missing standard fields; print type mismatches."""
    for field, expected_type in _EXPECTED_TYPES.items():
        if field not in EEG:
            print(f"Field '{field}' is missing from the EEG dictionary, adding it.")
            EEG[field] = _default_for(expected_type)
            continue

        value = EEG[field]
        if expected_type is np.ndarray:
            if not isinstance(value, np.ndarray):
                print(f"Field '{field}' is expected to be a numpy array but is of type {type(value).__name__}.")
            elif field in _FLOAT_ARRAY_FIELDS:
                if not np.issubdtype(value.dtype, np.floating):
                    print(f"Field '{field}' is expected to be a numpy array of floats but has dtype {value.dtype}.")
            elif field in _INT_ARRAY_FIELDS:
                if not np.issubdtype(value.dtype, np.integer):
                    print(f"Field '{field}' is expected to be a numpy array of integers but has dtype {value.dtype}.")
            elif field in _DICT_ARRAY_FIELDS:
                # .flat visits every element, so 0-d and N-d arrays work too
                if not all(isinstance(item, dict) for item in value.flat):
                    print(f"Field '{field}' is expected to be a numpy array of dictionaries but contains other types.")
        elif not isinstance(value, expected_type):
            # an empty array is an accepted placeholder for any field
            if isinstance(value, np.ndarray) and value.size == 0:
                continue
            print(f"Field '{field}' is expected to be of type {expected_type} but is of type {type(value).__name__}.")


def eeg_checkset(EEG, load_data=True):
    """
    Validate and normalize an EEG dataset dictionary in place.

    Steps:
      - coerce 'nbchan', 'pnts', 'trials' to int, inferring missing ones from
        ``EEG['data'].shape`` (trials defaults to 1 for 2-D data);
      - convert 'event' and 'chanlocs' to numpy object arrays (a single dict
        is wrapped first; a missing field becomes an empty array);
      - add empty 'chaninfo' / 'reject' dicts when absent;
      - when ``EEG['data']`` is a file name and ``load_data`` is True, read
        the float32 data file into a ``(nbchan, pnts, trials)`` array;
      - compute ``EEG['icaact']`` when icaweights/icasphere are non-empty and
        the data is an in-memory array;
      - drop a trailing singleton trial axis from 3-D data;
      - coerce 'xmin', 'xmax', 'srate' to float;
      - add placeholders for missing standard fields and print a message for
        each field of an unexpected type.

    Args:
        EEG (dict): dataset dictionary; modified in place.
        load_data (bool): read the data file when ``EEG['data']`` is a str.
            If False the file name is left untouched and the ICA/squeeze
            steps are skipped.

    Returns:
        dict: the same ``EEG`` object.

    Raises:
        FileNotFoundError: if the referenced data file cannot be found.
        KeyError: if 'xmin', 'xmax' or 'srate' is missing.
        Exception: in strict mode (the default, see ``strict_mode``), errors
            raised while computing ICA activations propagate; otherwise they
            are logged and 'icaact' is set to an empty array.
    """
    # In strict mode we catch DummyException (never raised) so errors
    # propagate; in non-strict mode we catch Exception and handle gracefully.
    exception_type = DummyException if _strict_mode_var.get() else Exception

    # Dimensions: coerce to int, or infer from the in-memory data array.
    if 'nbchan' in EEG:
        EEG['nbchan'] = int(EEG['nbchan'])
    else:
        EEG['nbchan'] = EEG['data'].shape[0]
    if 'pnts' in EEG:
        EEG['pnts'] = int(EEG['pnts'])
    else:
        EEG['pnts'] = EEG['data'].shape[1]
    if 'trials' in EEG:
        EEG['trials'] = int(EEG['trials'])
    else:
        EEG['trials'] = EEG['data'].shape[2] if EEG['data'].ndim == 3 else 1

    EEG['event'] = _as_object_array(EEG.get('event', []))
    EEG['chanlocs'] = _as_object_array(EEG.get('chanlocs', []))
    EEG.setdefault('chaninfo', {})
    EEG.setdefault('reject', {})

    if isinstance(EEG.get('data'), str) and load_data:
        EEG['data'] = _load_data_file(EEG)

    # With load_data=False the data may still be a file name; the array-only
    # steps below would fail on it, so they run only for real arrays.
    has_data_array = isinstance(EEG.get('data'), np.ndarray)

    if has_data_array and _has_ica_matrices(EEG):
        EEG['icaact'] = _compute_icaact(EEG, exception_type)

    # single-trial data is stored 2-D
    if has_data_array and EEG['data'].ndim == 3 and EEG['data'].shape[2] == 1:
        EEG['data'] = np.squeeze(EEG['data'], axis=2)

    # type conversion
    EEG['xmin'] = float(EEG['xmin'])
    EEG['xmax'] = float(EEG['xmax'])
    EEG['srate'] = float(EEG['srate'])

    _check_field_types(EEG)
    return EEG

def test_eeg_checkset():
    """Smoke test: load a sample .set file and run eeg_checkset on it.

    Requires the ``eegprep`` package and the sample file below; the path is
    relative to the current working directory.
    """
    from eegprep.pop_loadset import pop_loadset

    eeglab_file_path = './data/eeglab_data_with_ica_tmp_out2.set'
    EEG = pop_loadset(eeglab_file_path)
    EEG = eeg_checkset(EEG)

    # eeg_checkset always coerces these scalar fields
    for field in ('nbchan', 'pnts', 'trials'):
        assert isinstance(EEG[field], int), field
    for field in ('xmin', 'xmax', 'srate'):
        assert isinstance(EEG[field], float), field
    print('Checkset done')


if __name__ == '__main__':
    test_eeg_checkset()