# Public API of this module: the grid-generating source algorithms below.
__all__ = [
    'CreateEvenRectilinearGrid',
    'CreateUniformGrid',
    'CreateTensorMesh',
]

# Display name for this module (presumably consumed by the package's
# plugin/menu machinery, like the classes' ``__displayname__`` -- confirm).
__displayname__ = 'Grids'

import numpy as np
import vtk
from vtk.numpy_interface import dataset_adapter as dsa

from .. import _helpers, interface
from ..base import AlgorithmBase


def _makeSpatialCellData(nx, ny, nz):
    """Used for testing: build a flattened array with ``nx*ny*nz`` values.

    The value at index ``(i, j, k)`` along (X, Y, Z) is ``i*j*k``. The array
    is generated with shape ``(nz, ny, nx)`` and flattened in C order, so the
    X index varies fastest in the returned 1-D array.

    Args:
        nx (int): number of entries along the X-axis.
        ny (int): number of entries along the Y-axis.
        nz (int): number of entries along the Z-axis.

    Return:
        np.ndarray(float): 1-D array of length ``nx*ny*nz``.
    """
    # Shape must be (nz, ny, nx): the original used (nz, ny, nz), which
    # produced the wrong number of values whenever nx != nz.
    arr = np.fromfunction(lambda k, j, i: k*j*i, (nz, ny, nx))
    return arr.flatten()


class CreateUniformGrid(AlgorithmBase):
    """Create uniform grid (``vtkImageData``)

    The output is populated with two test arrays: ``'Spatial Cell Data'``
    on the cells and ``'Spatial Point Data'`` on the points.

    Args:
        extent (list(int)): number of points along the X, Y and Z axes.
        spacing (list(float)): distance between points along each axis.
        origin (list(float)): coordinates of the grid's first point.
    """
    __displayname__ = 'Create Uniform Grid'
    __category__ = 'source'
    def __init__(self,
                 extent=(10, 10, 10),
                 spacing=(1.0, 1.0, 1.0),
                 origin=(0.0, 0.0, 0.0)):
        AlgorithmBase.__init__(self,
            nInputPorts=0,
            nOutputPorts=1, outputType='vtkImageData')
        # Store private list copies. The old list defaults were shared by
        # every instance, and a caller's list could alias this object's
        # state. Lists also keep the ``!=`` checks in the setters meaningful
        # when tuples are passed in.
        self.__extent = list(extent)
        self.__spacing = list(spacing)
        self.__origin = list(origin)


    def RequestData(self, request, inInfo, outInfo):
        """Used by pipeline to generate the output ``vtkImageData``."""
        pdo = self.GetOutputData(outInfo, 0)
        nx, ny, nz = self.__extent[0], self.__extent[1], self.__extent[2]
        sx, sy, sz = self.__spacing[0], self.__spacing[1], self.__spacing[2]
        ox, oy, oz = self.__origin[0], self.__origin[1], self.__origin[2]
        # Setup the ImageData
        pdo.SetDimensions(nx, ny, nz)
        pdo.SetOrigin(ox, oy, oz)
        pdo.SetSpacing(sx, sy, sz)
        # Add CELL data: there is one fewer cell than points along each axis
        data = _makeSpatialCellData(nx-1, ny-1, nz-1)
        data = interface.convertArray(data, name='Spatial Cell Data', deep=True)
        pdo.GetCellData().AddArray(data)
        # Add POINT data: one value per point
        data = _makeSpatialCellData(nx, ny, nz)
        data = interface.convertArray(data, name='Spatial Point Data', deep=True)
        pdo.GetPointData().AddArray(data)
        return 1


    def RequestInformation(self, request, inInfo, outInfo):
        """Used by pipeline to set the output whole extent."""
        # Point indices run from 0 to (number of points - 1) on each axis
        ext = [0, self.__extent[0]-1, 0, self.__extent[1]-1, 0, self.__extent[2]-1]
        info = outInfo.GetInformationObject(0)
        # Set WHOLE_EXTENT: This is absolutely necessary
        info.Set(vtk.vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT(), ext, 6)
        return 1


    #### Setters / Getters ####


    def SetExtent(self, nx, ny, nz):
        """Set the extent of the output grid (number of points per axis).
        """
        if self.__extent != [nx, ny, nz]:
            self.__extent = [nx, ny, nz]
            self.Modified()

    def SetSpacing(self, dx, dy, dz):
        """Set the spacing for the points along each axial direction.
        """
        if self.__spacing != [dx, dy, dz]:
            self.__spacing = [dx, dy, dz]
            self.Modified()

    def SetOrigin(self, x0, y0, z0):
        """Set the origin of the output grid.
        """
        if self.__origin != [x0, y0, z0]:
            self.__origin = [x0, y0, z0]
            self.Modified()





class CreateEvenRectilinearGrid(AlgorithmBase):
    """This creates a vtkRectilinearGrid where the discretization along a
    given axis is uniformly distributed.

    The cells are populated with a test array named ``'Spatial Data'``.

    Args:
        extent (list(int)): number of CELLS along the X, Y and Z axes (the
            grid has one more point than cells along each axis).
        xrng (list(float)): (min, max) coordinates along the X-axis.
        yrng (list(float)): (min, max) coordinates along the Y-axis.
        zrng (list(float)): (min, max) coordinates along the Z-axis.
    """
    __displayname__ = 'Create Even Rectilinear Grid'
    __category__ = 'source'
    def __init__(self,
                 extent=(10, 10, 10),
                 xrng=(-1.0, 1.0),
                 yrng=(-1.0, 1.0),
                 zrng=(-1.0, 1.0)):
        AlgorithmBase.__init__(self,
            nInputPorts=0,
            nOutputPorts=1, outputType='vtkRectilinearGrid')
        # Store private list copies so neither a shared default nor a
        # caller-held list aliases this instance's state; lists also keep
        # the ``!=`` comparisons in the setters meaningful for tuple input.
        self.__extent = list(extent)
        self.__xrange = list(xrng)
        self.__yrange = list(yrng)
        self.__zrange = list(zrng)


    def RequestData(self, request, inInfo, outInfo):
        """Used by pipeline to generate the output ``vtkRectilinearGrid``."""
        # Get output of Proxy
        pdo = self.GetOutputData(outInfo, 0)
        # Perform task: number of points is number of cells + 1 on each axis
        nx, ny, nz = self.__extent[0]+1, self.__extent[1]+1, self.__extent[2]+1

        xcoords = np.linspace(self.__xrange[0], self.__xrange[1], num=nx)
        ycoords = np.linspace(self.__yrange[0], self.__yrange[1], num=ny)
        zcoords = np.linspace(self.__zrange[0], self.__zrange[1], num=nz)

        # CONVERT TO VTK #
        xcoords = interface.convertArray(xcoords, deep=True)
        ycoords = interface.convertArray(ycoords, deep=True)
        zcoords = interface.convertArray(zcoords, deep=True)

        pdo.SetDimensions(nx, ny, nz)
        pdo.SetXCoordinates(xcoords)
        pdo.SetYCoordinates(ycoords)
        pdo.SetZCoordinates(zcoords)

        # One value per cell, hence the minus one on each axis
        data = _makeSpatialCellData(nx-1, ny-1, nz-1)
        data = interface.convertArray(data, name='Spatial Data', deep=True)
        # THIS IS CELL DATA! Add the model data to CELL data:
        pdo.GetCellData().AddArray(data)
        return 1


    def RequestInformation(self, request, inInfo, outInfo):
        """Used by pipeline to set the output whole extent."""
        # ``extent`` counts cells, so point indices run from 0 to extent
        ext = [0, self.__extent[0], 0, self.__extent[1], 0, self.__extent[2]]
        info = outInfo.GetInformationObject(0)
        # Set WHOLE_EXTENT: This is absolutely necessary
        info.Set(vtk.vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT(), ext, 6)
        return 1


    #### Setters / Getters ####


    def SetExtent(self, nx, ny, nz):
        """Set the extent of the output grid (number of cells per axis).
        """
        if self.__extent != [nx, ny, nz]:
            self.__extent = [nx, ny, nz]
            self.Modified()

    def SetXRange(self, start, stop):
        """Set range (min, max) for the grid in the X-direction.
        """
        if self.__xrange != [start, stop]:
            self.__xrange = [start, stop]
            self.Modified()

    def SetYRange(self, start, stop):
        """Set range (min, max) for the grid in the Y-direction.
        """
        if self.__yrange != [start, stop]:
            self.__yrange = [start, stop]
            self.Modified()

    def SetZRange(self, start, stop):
        """Set range (min, max) for the grid in the Z-direction.
        """
        if self.__zrange != [start, stop]:
            self.__zrange = [start, stop]
            self.Modified()



class CreateTensorMesh(AlgorithmBase):
    """This creates a vtkRectilinearGrid (tensor mesh) whose cell widths
    along each axis are given explicitly, e.g. by UBC-style cell strings such
    as ``'200 100 50 20*50.0 50 100 200'`` where ``N*w`` expands to ``N``
    cells of width ``w``. The cells are populated with random data.

    Args:
        origin (list(float)): X and Y of the south-west corner and Z of the
            TOP of the mesh (Z cell widths are listed from the top down).
        dataname (str): name given to cell data passed to ``_AddModelData``.
        xcellstr (str): UBC-style cell widths along the X-axis.
        ycellstr (str): UBC-style cell widths along the Y-axis.
        zcellstr (str): UBC-style cell widths along the Z-axis.
    """
    __displayname__ = 'Create Tensor Mesh'
    __category__ = 'source'
    def __init__(self, origin=(-350.0, -400.0, 0.0), dataname='Data',
            xcellstr='200 100 50 20*50.0 50 100 200',
            ycellstr='200 100 50 21*50.0 50 100 200',
            zcellstr='20*25.0 50 100 200',):
        AlgorithmBase.__init__(self, nInputPorts=0,
            nOutputPorts=1, outputType='vtkRectilinearGrid')
        # Private list copy: a shared default or caller-held list must not
        # alias this instance's state.
        self.__origin = list(origin)
        self.__xcells = CreateTensorMesh._ReadCellLine(xcellstr)
        self.__ycells = CreateTensorMesh._ReadCellLine(ycellstr)
        self.__zcells = CreateTensorMesh._ReadCellLine(zcellstr)
        self.__dataName = dataname


    @staticmethod
    def _ReadCellLine(line):
        """Read cell sizes from a UBC mesh line string.

        Each whitespace-separated token is either a single width (``'50'``)
        or a run ``'N*w'`` meaning ``N`` cells of width ``w``.

        Return:
            np.ndarray(float): the expanded cell widths, in order.

        Raises:
            ValueError: if ``line`` has no tokens or a token is malformed.
        """
        # TODO: when optimized, make sure to combine with UBC reader
        line_list = []
        for seg in line.split():
            if '*' in seg:
                sp = seg.split('*')
                seg_arr = np.full((int(sp[0]),), float(sp[1]), dtype=float)
            else:
                seg_arr = np.array([float(seg)], dtype=float)
            line_list.append(seg_arr)
        return np.concatenate(line_list)


    @staticmethod
    def _CellsDiffer(old, new):
        """Return True if the cell-width sequences ``old`` and ``new`` differ."""
        # Length check first: np.allclose cannot compare mismatched lengths
        return len(new) != len(old) or not np.allclose(old, new)


    def GetExtent(self):
        """Get the whole extent ``(0, nx, 0, ny, 0, nz)`` where ``nx``,
        ``ny`` and ``nz`` are the number of cells along each axis."""
        ne, nn, nz = len(self.__xcells), len(self.__ycells), len(self.__zcells)
        return (0, ne, 0, nn, 0, nz)



    def _MakeModel(self, pdo):
        """Set the dimensions and point coordinates of ``pdo`` from the
        origin and cell widths. Returns ``pdo``."""
        ox, oy, oz = self.__origin[0], self.__origin[1], self.__origin[2]

        # Read the cell sizes
        cx = self.__xcells
        cy = self.__ycells
        cz = self.__zcells

        # Invert the indexing of the vector to start from the bottom.
        cz = cz[::-1]
        # Adjust the reference point to the bottom south west corner
        oz = oz - np.sum(cz)

        # Now generate the coordinates from cell widths and origin
        cox = ox + np.cumsum(cx)
        cox = np.insert(cox, 0, ox)
        coy = oy + np.cumsum(cy)
        coy = np.insert(coy, 0, oy)
        coz = oz + np.cumsum(cz)
        coz = np.insert(coz, 0, oz)

        # Set the dims (number of points = number of cells + 1) and coordinates
        ext = self.GetExtent()
        nx, ny, nz = ext[1]+1, ext[3]+1, ext[5]+1
        pdo.SetDimensions(nx, ny, nz)
        # Convert to VTK array for setting coordinates
        pdo.SetXCoordinates(interface.convertArray(cox, deep=True))
        pdo.SetYCoordinates(interface.convertArray(coy, deep=True))
        pdo.SetZCoordinates(interface.convertArray(coz, deep=True))

        return pdo


    def _AddModelData(self, pdo, data):
        """Add a cell data array to ``pdo`` and return ``pdo``.

        If ``data`` is None, an array of uniform random values named
        ``'Random Data'`` is added; otherwise ``data`` is added under the
        name given by ``dataname`` at construction.
        """
        nx, ny, nz = pdo.GetDimensions()
        nx, ny, nz = nx-1, ny-1, nz-1
        # ADD DATA to cells
        if data is None:
            data = np.random.rand(nx*ny*nz)
            data = interface.convertArray(data, name='Random Data', deep=True)
        else:
            # Was an undefined name (``dataNm``), which raised NameError
            data = interface.convertArray(data, name=self.__dataName, deep=True)
        pdo.GetCellData().AddArray(data)
        return pdo


    def RequestData(self, request, inInfo, outInfo):
        """Used by pipeline to generate output data object
        """
        # Get input/output of Proxy
        pdo = self.GetOutputData(outInfo, 0)
        # Perform the task
        self._MakeModel(pdo)
        self._AddModelData(pdo, None) # TODO: add ability to set input data
        return 1


    def RequestInformation(self, request, inInfo, outInfo):
        """Used by pipeline to set output whole extent
        """
        # Now set whole output extent
        ext = self.GetExtent()
        info = outInfo.GetInformationObject(0)
        # Set WHOLE_EXTENT: This is absolutely necessary
        info.Set(vtk.vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT(), ext, 6)
        return 1


    #### Getters / Setters ####


    def SetOrigin(self, x0, y0, z0):
        """Set the origin of the output (X, Y of the south-west corner and
        Z of the top of the mesh).
        """
        if self.__origin != [x0, y0, z0]:
            self.__origin = [x0, y0, z0]
            self.Modified()

    def SetXCells(self, xcells):
        """Set the spacings for the cells in the X direction

        Args:
            xcells (list or np.array(floats)) : the spacings along the X-axis"""
        # Store a float copy so later mutation of the caller's object is
        # not silently picked up
        xcells = np.array(xcells, dtype=float)
        if CreateTensorMesh._CellsDiffer(self.__xcells, xcells):
            self.__xcells = xcells
            self.Modified()

    def SetYCells(self, ycells):
        """Set the spacings for the cells in the Y direction

        Args:
            ycells (list or np.array(floats)) : the spacings along the Y-axis"""
        ycells = np.array(ycells, dtype=float)
        if CreateTensorMesh._CellsDiffer(self.__ycells, ycells):
            self.__ycells = ycells
            self.Modified()

    def SetZCells(self, zcells):
        """Set the spacings for the cells in the Z direction (top down)

        Args:
            zcells (list or np.array(floats)): the spacings along the Z-axis"""
        zcells = np.array(zcells, dtype=float)
        if CreateTensorMesh._CellsDiffer(self.__zcells, zcells):
            self.__zcells = zcells
            self.Modified()

    def SetXCellsStr(self, xcellstr):
        """Set the spacings for the cells in the X direction

        Args:
            xcellstr (str) : the spacings along the X-axis in the UBC style"""
        xcells = CreateTensorMesh._ReadCellLine(xcellstr)
        self.SetXCells(xcells)

    def SetYCellsStr(self, ycellstr):
        """Set the spacings for the cells in the Y direction

        Args:
            ycellstr (str) : the spacings along the Y-axis in the UBC style"""
        ycells = CreateTensorMesh._ReadCellLine(ycellstr)
        self.SetYCells(ycells)

    def SetZCellsStr(self, zcellstr):
        """Set the spacings for the cells in the Z direction

        Args:
            zcellstr (str)  : the spacings along the Z-axis in the UBC style"""
        zcells = CreateTensorMesh._ReadCellLine(zcellstr)
        self.SetZCells(zcells)
