import numpy as np
from pathlib import Path
import xml.etree.ElementTree as ET
import itertools
from nmrtoolbox.util import PeakOutsideOfMaskError, AxisData, SpectrumData, Range


class Mask:
    """
    A class for defining a subset of the Nyquist grid.  The intended use is for NUScon where synthetic peaks are
    injected into the empty regions of an empirical spectrum.  This mask can define the empty regions so that the
    recovered peak list can be restricted to just regions where the synthetic peaks have been injected.
    The Mask is a numpy array sized to match the Nyquist grid with bool values where True indicates empty.
    """

    def __init__(self, file):
        """
        :param file: file containing the empty regions of an empirical spectrum, as generated by CONNJUR
        :raises TypeError: if file cannot be interpreted as a path
        :raises FileNotFoundError: if file does not point to an existing file
        :raises ValueError: if the axis properties cannot be derived from the xml metadata of the file
        """
        try:
            self.file = Path(file)
        except TypeError as e:
            raise TypeError(f'provide path to mask file.  input path not recognized: {file}') from e
        if not self.file.is_file():
            raise FileNotFoundError(f'did not find mask file: {self.file}')

        # the xml includes custom namespace prefixes which need to be defined and used by xml.find()
        self.ns = {
            'connjur': "https://raw.githubusercontent.com/CONNJUR/CONNJUR_ML/master/connjur_ml.xsd",
            'builder': "https://raw.githubusercontent.com/CONNJUR/CONNJUR_ML/master/connjur_ml.xsd",
            'xsi': "http://www.w3.org/2001/XMLSchema-instance",
            'premis': "http://www.loc.gov/premis/v3",
        }

        # parse the tabular data file - capture the sections of the file and several properties
        self.header = []
        self.xml = None
        self.mask = None

        # containers for information about the axes and about the spectrum (i.e. global properties)
        self.axis = AxisData()
        self.meta = SpectrumData()

        # read the mask data and derive additional properties that go in self.axis and self.meta
        self._parse_file()
        try:
            self._get_axis_properties()
        except Exception as e:
            raise ValueError(f'unable to parse mask file\n{e}') from e

    def _parse_file(self):
        """
        Parse the tabular file that CWB generates - it contains a plain text header, xml metadata, and tabular data.

        Fills in self.header (plain text lines before the xml), self.xml (the parsed xml root), the 'ndims' entry of
        self.meta, the 'points' entry of self.axis, and self.mask (True = empty, False = listed as non-empty).
        """
        xml_list = []
        ndims = None
        mode = 'header'

        with open(self.file) as file:
            for raw_line in file:
                line = raw_line.rstrip()
                if not line:
                    # skip blank lines rather than treating them as end-of-file; stopping here would silently
                    # drop every section (xml, tabular data) that follows a blank separator line
                    continue

                # -------------------------------------------
                # 1. read the plain text header
                if mode == 'header':
                    if not line.startswith('<?xml'):
                        self.header.append(line)
                        continue
                    # found the xml - move onto the next parsing stage with this line
                    mode = 'xml'

                # -------------------------------------------
                # 2. read the xml metadata
                if mode == 'xml':
                    xml_list.append(line)

                    if '</premis:premis>' in line:
                        # found the end of the xml header => parse it, init the mask, and move on
                        self.xml = ET.fromstringlist(xml_list)

                        ndims = self.get_xml('connjur:spectralAxes', dtype=int)[0]
                        self.meta.set(
                            name='ndims',
                            value=ndims,
                        )

                        # this is the first data written to self.axis, so an explicit axis_list must be provided
                        # subsequent additions to self.axis will use existing axes
                        self.axis.set_list(
                            name='points',
                            value_list=self.get_xml('connjur:totalPoints', dtype=int),
                            axis_list=self.get_xml('connjur:decoupledNucleus', dtype=str),
                        )

                        # start with every indel marked empty (True); the table only needs to flag non-empty ones
                        self.mask = np.ones(self.axis.get_field('points'), dtype=bool)

                        mode = 'predata'
                    continue

                # -------------------------------------------
                # 3. skip anything that might appear after xml and before tabular data
                if mode == 'predata':
                    if "tabular data follows" in line:
                        mode = 'table_header'
                    continue

                # -------------------------------------------
                # 4. skip the single table header line
                if mode == 'table_header':
                    mode = 'data'
                    continue

                # -------------------------------------------
                # 5. read the tabular data (mode == 'data') and build the mask from it
                vals = line.split()
                # NOTE: since this is a mask the values are expected to be 0.0000 (not empty) or 1.0000 (empty)
                # need float() to parse the string and then convert to bool()
                # NOTE: mask is initialized as True, so just find and record the False values
                # (there should be far fewer non-empty values (False) in the mask)
                if not bool(float(vals[-1])):
                    # the first ndims columns are the indel coordinates; phase (vals[-2]) is ignored for now
                    indel = tuple(map(int, vals[:ndims]))
                    self.mask[indel] = False

    def _get_axis_properties(self):
        """
        Use metadata in xml of the mask file to compute how many points there are per ppm along each axis.  This will be
        used to convert an arbitrary peak's position (in ppm) into its corresponding location in the mask (in pts).

        Adds the 'pts_per_ppm' and 'range' (unit ppm) entries to self.axis.
        """

        # the mask files include SW in Hz, but not in ppm.  use carrier frequency (sf) to compute SW in ppm
        sw_hz = self.get_xml('connjur:sweepwidth')
        sf_mhz = self.get_xml('connjur:spectralFrequency')
        first_hz = self.get_xml('connjur:firstScalePoint')

        # sweep width of each axis in ppm
        sw_ppm = sw_hz / sf_mhz

        # number of points along each axis per ppm (n points means n-1 intervals)
        self.axis.set_list(
            name='pts_per_ppm',
            value_list=(np.asarray(self.axis.get_field('points')) - 1) / sw_ppm,
        )

        # The "first" point along each axis in ppm.  NOTE: This is the smaller value of the ppm range furthest "right".
        # The "last" point along each axis in ppm.   NOTE: This is the largest value of the ppm range furthest "left".
        first_pt_ppm = first_hz / sf_mhz
        last_pt_ppm = (first_hz + sw_hz) / sf_mhz

        value_list = [Range(min=first, max=last) for first, last in zip(first_pt_ppm, last_pt_ppm)]
        self.axis.set_list(
            name='range',
            value_list=value_list,
            unit='ppm',
        )

    def print_info(self):
        """Print the number and percentage of empty (True) and non-empty (False) indels in the mask."""
        num_indels = self.mask.size
        empty_num = np.count_nonzero(self.mask)
        nonempty_num = num_indels - empty_num
        # a zero-size mask (an axis with no points) would otherwise raise ZeroDivisionError
        empty_percent = (empty_num / num_indels) * 100 if num_indels else 0.0
        nonempty_percent = (nonempty_num / num_indels) * 100 if num_indels else 0.0

        print(f'empty indels:    {empty_num:12,d} / {empty_percent:5.2f}%')
        print(f'nonempty indels: {nonempty_num:12,d} / {nonempty_percent:5.2f}%')

    def get_xml(self, name, dtype=float):
        """
        Get the value(s) listed in the xml header for the field "name" and convert to dtype.

        :param name: namespaced tag (e.g. 'connjur:totalPoints'), searched anywhere below the xml root
        :param dtype: numpy-compatible dtype used to convert the element text
        :return: numpy array with one entry per matching element, in document order
        """
        return np.array([e.text for e in self.xml.findall(f'.//{name}', self.ns)], dtype=dtype)

    def ppm_in_mask(self, peak_ppm, peak_axis_property, box_radius=2):
        """
        map the position of the peak in ppm onto the indel coordinates (points) used by the mask

        consider box_radius points along each dimension above and below the (off-grid) peak position.
        this defines a hypercube around the peak position.  the peak position is considered to be in an empty region if
        any of the points in the hypercube are empty in the mask.  parts of the hypercube that fall outside the mask
        are ignored.

        box_radius can be given as single integer value (used for all dims) or as a list of values (one for each dim)

        :param peak_ppm: peak position in ppm, one value per mask dimension
        :param peak_axis_property: AxisData for the peak; its axis labels must match those of the mask
        :param box_radius: integer or per-dimension integers giving the half-width of the search box
        :return: True if any in-grid indel of the search box is empty, otherwise False
        :raises TypeError: if peak_axis_property is not AxisData or box_radius cannot be safely cast to int
        :raises ValueError: on mismatched axis labels or dimensions
        :raises PeakOutsideOfMaskError: if the peak lies outside the ppm range of the mask
        """
        if not isinstance(peak_axis_property, AxisData):
            raise TypeError('must provide info about axes as AxisData object')
        if peak_axis_property.keys() != self.axis.keys():
            raise ValueError('axis labels for peak do not match labels for mask')

        try:
            peak_ppm = np.asarray(peak_ppm)
        except ValueError as e:
            raise ValueError('failed to convert peak position to numpy array') from e

        if peak_ppm.ndim != 1 or peak_ppm.size != self.mask.ndim:
            raise ValueError('peak position must provide exactly one value per mask dimension')

        if np.ndim(box_radius) == 0:
            # a scalar (python or numpy) radius applies to every dimension; the safe cast below validates it
            box_radius = np.full(peak_ppm.shape, box_radius)
        try:
            # convert to array of int and only allow conversions to int where values are preserved
            box_radius = np.asarray(box_radius).astype(int, casting='safe')
        except TypeError as e:
            raise TypeError('failed to convert box_radius into numpy.array of integers') from e

        if box_radius.shape != peak_ppm.shape:
            raise ValueError('box radius and peak position do not have same dimensions')

        # the peak position must lie within the ppm range of each axis (axes are paired with positions by order)
        for axis, p in zip(self.axis.keys(), peak_ppm):
            ppm_range = self.axis.get_field(name='range', axis=axis)[0]
            if not ppm_range.contains(p):
                raise PeakOutsideOfMaskError(f'peak position is not in the range of the mask: {peak_ppm}')

        # use position of peak relative to the "last" point to get index location
        last_pt_ppm = np.asarray([r.max for r in self.axis.get_field('range')])
        pts_per_ppm = np.asarray(self.axis.get_field('pts_per_ppm'))
        peak_pts = np.abs(peak_ppm - last_pt_ppm) * pts_per_ppm

        # position of peak in points will not be on-grid, so bracket it and extend by the box radius
        lower = np.floor(peak_pts).astype(int) - (box_radius - 1)
        upper = np.ceil(peak_pts).astype(int) + (box_radius - 1)

        # clamp the box to the grid: negative starts/stops must not wrap around to the end of the array (python
        # negative indexing), while numpy slicing already truncates stops that run past the end of an axis
        box = tuple(slice(max(lo, 0), max(hi + 1, 0)) for lo, hi in zip(lower, upper))

        # any empty (True) indel inside the in-grid part of the box counts; an empty selection yields False
        return bool(self.mask[box].any())
