#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 27 22:36:21 2023

@author: CSChisholm

These functions are primarily used for reading data generated by XMDS2
(http://xmds.org).
ParseXSIL is specific to extracting metadata from the .xsil file generated
by XMDS2.
ReadH5 loads all of the outputs defined in an XMDS2 script into a Python
dictionary, but it can also read any other HDF5 file into a dictionary.
Likewise, WriteH5 writes any dictionary to an HDF5 file.
"""

import h5py
import xml.etree.ElementTree as ET
import numpy as np

def ReadH5(filename: str) -> dict:
    '''Load an HDF5 file into nested Python dictionaries.

    The file's root group and every sub-group become dictionaries keyed by
    member name; datasets are read into memory and stored as the values,
    sub-groups as nested dictionaries.

    return data'''
    result = {}
    with h5py.File(filename, 'r') as h5file:
        _Group2Dict(h5file, result)
    return result

def WriteH5(filename: str, data: dict):
    '''Write a (possibly nested) dictionary to an HDF5 file.

    Nested dictionaries become groups; every other value is written as a
    dataset under its key. The file is created/truncated ('w' mode).

    Raises TypeError if data is not a dictionary (checked before the file
    is opened, so an existing file is left untouched).'''
    if isinstance(data, dict):
        with h5py.File(filename, 'w') as h5file:
            _Dict2Group(h5file, data)
        return
    raise TypeError(f'Data must be a dictionary. Got type: {type(data)}.')

def ParseXSIL(xsilFile: str) -> dict:
    '''Parse most of the metadata from a .xsil file (XMDS2) into a dictionary.

    Geometry and output data are not extracted because they are stored in
    the corresponding .h5 file (see ReadH5).

    Keys of the returned dictionary:
        'Filename'    -- xsilFile, as passed in
        'Name'        -- text of the top-level <name> element
        'Description' -- text of the top-level <description> element
        'Vectors'     -- vector name -> first line containing '=' in the
                         text of the vector's second child element or, when
                         that element has a 'kind' attribute, the text of its
                         first child (the HDF5 file it is initialised from)
        'Breakfile'   -- 'filename' attribute of <sequence>/<breakpoint>,
                         or None if there is no breakpoint
        'Parameters'  -- attribute dictionary of <sequence>/<integrate>
        'Operators'   -- assignments in the CDATA block starting at the
                         first line containing <operator_names>
        'Equations'   -- assignments in the CDATA block starting at the
                         first line containing <integration_vectors>
        'Variables'   -- numeric values of command line arguments, global
                         'const real' definitions and derived <arguments>
                         definitions
        'Comments'    -- variable name -> its trailing '//' comment, or
                         'Command line argument'
        'Axes'        -- transverse dimension name -> its remaining
                         attributes

    For every axis with transform="bessel", the upper end of its domain is
    additionally stored in data['Variables'] as '<axis name>Outer'.

    Raises IndexError if a required element or marker line is missing.

    return data'''
    data = {'Filename': xsilFile} #Construct dictionary
    variables = {} #The variable definitions, will later be added to data
    comments = {} #Store the C++ comments as strings
    vectors = {} #Defined vectors (potential, initialisation etc.)
    axes = {}
    xmlData = ET.parse(xsilFile).getroot() #Load XML
    children = [child.tag for child in xmlData]
    data['Name'] = _SearchElement('name',xmlData,children).text #Get program name
    data['Description'] = _SearchElement('description',xmlData,children).text
    #We also want the file as text so that we can access <![CDATA[ blocks.
    #Decode explicitly as UTF-8 (the XML default, which ET.parse also uses)
    #instead of the platform locale, which fails on non-ASCII characters.
    with open(xsilFile,'r',encoding='utf-8') as f:
        lines = f.readlines()
    #Collect every top-level <vector> element, in document order
    for vec in (child for child in xmlData if child.tag == 'vector'):
        name = vec.attrib['name']
        if 'kind' not in vec[1].attrib:
            vectors[name] = [line for line in vec[1].text.split('\n') if '=' in line][0]
        else: #Initialised from hdf5 file
            vectors[name] = vec[1][0].text #Just put the file name here
    data['Vectors'] = vectors
    #Look in the sequence for operators, differential equations, and breakpoint
    sequence = _SearchElement('sequence',xmlData,children)
    try:
        data['Breakfile'] = _SearchElement('breakpoint',sequence).attrib['filename']
    except IndexError:
        data['Breakfile'] = None #no break file
    integrate = _SearchElement('integrate',sequence)
    data['Parameters'] = integrate.attrib
    data['Operators'] = _CDATAAssignments(lines,'<operator_names>')
    data['Equations'] = _CDATAAssignments(lines,'<integration_vectors>')
    #Search for command line arguments and store them in dict: variables
    info = _SearchElement('info',xmlData,children).text.split('\n')
    clargs = [line for line in info if 'Command line argument' in line]
    for clarg in clargs:
        definition = clarg.split('Command line argument ')[1]
        _GetVar(definition,variables)
        comments[definition.split(' = ')[0]] = 'Command line argument'
    #Get global variables
    features = _SearchElement('features',xmlData,children)
    globalvars = _SearchElement('globals',features).text.split('\n')
    globaldefs = [globalvar for globalvar in globalvars if 'const real' in globalvar]
    for globaldef in globaldefs:
        definition = globaldef.split('real ')[1].split(';')[0]
        _GetVar(definition,variables)
        if ('//' in globaldef):
            comments[definition.split(' = ')[0]] = globaldef.split('//')[1]
    #Get variables derived from command line arguments and globals
    subLines = [II for II, line in enumerate(lines) if '<arguments>' in line]
    if subLines:
        itr = subLines[0]
        #Stop at the end of the CDATA block or of <features>, whichever is first
        while not((']]>' in lines[itr]) or ('</features>' in lines[itr])):
            if (('=' in lines[itr]) and not('default_value' in lines[itr])):
                #Normalise spacing so definitions split cleanly on ' = '
                line = lines[itr].replace(' ','').replace('=',' = ')
                _GetVar(line.split(';')[0],variables)
                if ('//' in line):
                    comments[line.split(' = ')[0]] = line.split('//')[1]
            itr+=1
    data['Variables'] = variables
    data['Comments'] = comments
    geometry = _SearchElement('geometry',xmlData,children)
    transverseDimensions = _SearchElement('transverse_dimensions',geometry)
    for dimension in transverseDimensions:
        axes[dimension.attrib['name']] = {key: value for key, value in dimension.attrib.items() if not key=='name'}
    data['Axes'] = axes
    for key in data['Axes'].keys():
        if (('transform', 'bessel') in data['Axes'][key].items()):
            #Domain is written as "(lower, upper)"; keep the upper end
            data['Variables'][f'{key}Outer'] = float(data['Axes'][key]['domain'].split(', ')[1].split(')')[0])
    return data

def _CDATAAssignments(lines, marker):
    '''Collect "name = value;" assignments from a block of lines.

    Scanning starts at the first line containing marker and stops before the
    first line (at or after it) containing ']]>'. For every scanned line that
    contains '=', the key is the last space-separated token before ' = ' and
    the value is the text after ' = ' up to the first ';'.
    Helper function for ParseXSIL

    Raises IndexError if no line contains marker.

    return assignments'''
    itr = [II for II, line in enumerate(lines) if marker in line][0]
    assignments = {}
    while not(']]>' in lines[itr]):
        if ('=' in lines[itr]):
            linesplit = lines[itr].split(' = ')
            assignments[linesplit[0].split(' ')[-1]] = linesplit[1].split(';')[0]
        itr+=1
    return assignments

def _Group2Dict(h5Group,data):
    '''Recursively copy the contents of an HDF5 group into data.

    Datasets are read in full (dataset[()]) and stored under their name;
    any other member is treated as a sub-group and becomes a nested
    dictionary filled by a recursive call.
    Helper function for ReadH5. Mutates data in place; returns None.'''
    for name in list(h5Group):
        member = h5Group[name]
        if isinstance(member, h5py.Dataset):
            data[name] = member[()]
            continue
        data[name] = {}
        _Group2Dict(member, data[name])

def _Dict2Group(h5group,data):
    '''Recursively write a dictionary into an HDF5 group.

    Dictionary values become sub-groups (made with create_group and filled
    by a recursive call); every other value is assigned with
    h5group[key] = value, which writes it as a dataset.
    Helper function for WriteH5. Returns None.'''
    for name, value in data.items():
        if not isinstance(value, dict):
            h5group[name] = value
            continue
        _Dict2Group(h5group.create_group(name), value)

def _SearchElement(tag, root, tags=None,index=0):
    '''Return the child of root whose tag equals tag.

    tags may be a precomputed list of the child tags of root, in order;
    when omitted it is built from root. If several children match, index
    selects which match to return (0 is the first; negative values count
    back from the last match). Raises IndexError if there is no such match.
    Helper function for ParseXSIL

    return element'''
    if tags is None:
        tags = [child.tag for child in root]
    matches = [position for position, childTag in enumerate(tags) if childTag == tag]
    return root[matches[index]]

def _GetVar(definition, variables):
    '''Parse a C++-style definition "name = value" and store its numeric value.

    The text after the first ' = ' is converted with float() when possible.
    Otherwise it is evaluated as an arithmetic expression in which these
    names are available:
      * every entry already in variables (variables are defined
        sequentially, so anything an expression uses was stored earlier);
      * M_PI, and the C math functions sqrt and pow (as np.sqrt and
        np.power), plus exp, log, log10, sin, cos, tan, fabs and abs.
    Names are resolved as whole identifiers, so a variable whose name occurs
    inside another identifier (e.g. q inside sqrt) is not mangled.
    If the expression uses an unknown name or is not valid syntax, the
    definition is skipped and variables is left unchanged.
    Helper function for ParseXSIL.

    Raises IndexError if definition does not contain ' = '.
    Modifies variables in place; returns None.'''
    splitDef = definition.split(' = ')
    varName = splitDef[0].replace(' ','')
    expression = splitDef[1]
    try:
        variables[varName] = float(expression)
        return
    except ValueError:
        pass #Not a plain number: evaluate it as an expression below
    #Resolve names through an eval namespace rather than textual substitution,
    #so substrings of identifiers (and of printed numbers) are never rewritten.
    namespace = {
        'M_PI': np.pi,
        'sqrt': np.sqrt,
        'pow': np.power,
        'exp': np.exp,
        'log': np.log,
        'log10': np.log10,
        'sin': np.sin,
        'cos': np.cos,
        'tan': np.tan,
        'fabs': np.abs,
        'abs': np.abs,
    }
    namespace.update(variables) #User-defined names take precedence
    #SECURITY: eval is only as safe as the .xsil file being parsed. Builtins
    #are removed to limit what an expression can reach, but do not use this
    #on untrusted files.
    try:
        value = eval(expression.strip(), {'__builtins__': {}}, namespace)
    except (NameError, SyntaxError):
        return #Refers to something we cannot evaluate; skip this variable
    variables[varName] = value
    return