import linecache
import os
import subprocess

import matplotlib as mpl
import matplotlib.pyplot as plt
import meshio
import numpy as np
def readInternal(caseDir, m2mm=1, meshGenerator='blockMesh'):
    """Read the internal mesh of an OpenFOAM case from the VTK file of ``foamToVTK``.

    The file ``{caseDir}/VTK/{caseName}_0.vtk`` is read with meshio. If it does
    not exist, ``foamToVTK -useTimeName -case {caseDir}`` is run once to create it.

    Parameters
    ----------
    caseDir : str
        Path of the OpenFOAM case directory; its base name is the case name.
    m2mm : float, optional
        Factor the returned coordinates are multiplied by (default 1).
    meshGenerator : {'blockMesh', 'gmsh'}, optional
        Number of nodes every cell is expected to have on its maximum-z plane:
        4 for 'blockMesh' (reordered to draw a quadrilateral), 3 for 'gmsh'.

    Returns
    -------
    tuple or None
        ``(x, y, z, cells_poly)``: scaled point coordinates and, for every cell,
        the indices of its nodes lying at the cell's maximum z. ``cells_poly`` is
        an empty list for an unsupported ``meshGenerator``. None is returned when
        the VTK file cannot be found or generated.

    Raises
    ------
    ValueError
        If a cell does not have the expected number of nodes at its maximum z.
    """
    caseDir = os.path.abspath(caseDir)
    vtkfile = os.path.join(caseDir, 'VTK', '%s_0.vtk' % (os.path.basename(caseDir)))
    if not os.path.exists(vtkfile):
        # argument list without a shell: paths with spaces or shell
        # metacharacters are passed to foamToVTK unchanged
        cmd_args = ['foamToVTK', '-useTimeName', '-case', caseDir]
        cmd = ' '.join(cmd_args)
        try:
            subprocess.run(cmd_args, check=False)
        except OSError as err:  # e.g. foamToVTK is not on PATH
            print('Failed to run %s: %s' % (cmd, err))
        if not os.path.exists(vtkfile):
            print("The vtk file doesn't exist, please run `foamToVTK` for the case.\n%s"%(vtkfile))
            print('Tried to run command for you: %s'%(cmd))
            print("but the problem is still not solved, do it by hand please")
            return None
    data = meshio.read(vtkfile)
    cells = data.cells[0].data
    points = data.points
    nFrontNodes = {'blockMesh': 4, 'gmsh': 3}.get(meshGenerator)
    if nFrontNodes is None:
        print('Unknown meshGenerator %r (expected blockMesh or gmsh), cells_poly is empty' % (meshGenerator,))
        cells_poly = []
    else:
        # z coordinate of every node of every cell, shape (nCells, nNodesPerCell)
        z_nodes = points[cells, 2]
        on_front = z_nodes == z_nodes.max(axis=1, keepdims=True)
        if not np.all(on_front.sum(axis=1) == nFrontNodes):
            raise ValueError('every cell must have exactly %d nodes at its maximum z for meshGenerator=%s'
                             % (nFrontNodes, meshGenerator))
        # boolean indexing is row-major, so each row keeps the cell's original node order
        cells_poly = cells[on_front].reshape(-1, nFrontNodes).astype(int)
        if meshGenerator == 'blockMesh':
            # correct node connection order to construct the quadrilateral
            cells_poly = cells_poly[:, [0, 1, 3, 2]]
    x, y, z = points[:, 0], points[:, 1], points[:, 2]
    print("nPoints: %d, nPoints_2D: %d, nCells: %d"%(len(points), len(points)//2, len(cells)))
    return x*m2mm, y*m2mm, z*m2mm, cells_poly
def readPoints(caseDir):
    """Read the ASCII ``constant/polyMesh/points`` file of an OpenFOAM case.

    Only the multi-line list layout is parsed: a count line, a line holding
    just ``(``, one ``(x y z)`` entry per line and a line holding just ``)``.
    A warning is printed when the count does not match the entries read.

    Parameters
    ----------
    caseDir : str
        Path of the OpenFOAM case directory.

    Returns
    -------
    dict
        ``{'x': ndarray, 'y': ndarray, 'z': ndarray}`` of float coordinates.

    Raises
    ------
    FileNotFoundError
        If the points file does not exist.
    """
    caseDir = os.path.abspath(caseDir)
    file_points = os.path.join(caseDir, 'constant', 'polyMesh', 'points')
    # plain open() instead of linecache: a missing file raises rather than
    # silently yielding no points, and the process-wide linecache is untouched
    with open(file_points, encoding='utf-8') as f:
        alldata = [line.rstrip() for line in f]
    nPoints, start, end = 0, 0, 0
    for i, line in enumerate(alldata):
        if line == '(':
            nPoints = int(alldata[i-1])
            start = i + 1
        elif line == ')':
            end = i
    # each entry "(x y z)" -> ['x', 'y', 'z']
    coords = [line.replace('(', '').replace(')', '').split()[:3]
              for line in alldata[start:end] if line.strip()]
    xyz = np.array(coords, dtype=float).reshape(-1, 3)
    if len(xyz) != nPoints:
        print('The points file of case %s may not be correct, because the point number is not consistent' % (caseDir))
    return {'x': xyz[:, 0].copy(), 'y': xyz[:, 1].copy(), 'z': xyz[:, 2].copy()}
def readFaces(caseDir):
    """Read the ASCII ``constant/polyMesh/faces`` file of an OpenFOAM case.

    Only the multi-line list layout is parsed: a count line, a line holding
    just ``(``, one ``n(i0 i1 ...)`` face per line and a line holding just ``)``.
    A warning is printed when the count does not match the faces read.

    Parameters
    ----------
    caseDir : str
        Path of the OpenFOAM case directory.

    Returns
    -------
    dict
        ``{'nNodes': [int, ...], 'index': [ndarray, ...]}``: the node count
        written in front of each face and the face's point indices.

    Raises
    ------
    FileNotFoundError
        If the faces file does not exist.
    """
    caseDir = os.path.abspath(caseDir)
    file_faces = os.path.join(caseDir, 'constant', 'polyMesh', 'faces')
    with open(file_faces, encoding='utf-8') as f:
        alldata = [line.rstrip() for line in f]
    nFaces, face_start, face_end = 0, 0, 0  # face_end is exclusive
    for i, line in enumerate(alldata):
        if line == '(':
            nFaces = int(alldata[i-1])
            face_start = i + 1
        elif line == ')':
            face_end = i
    if nFaces != face_end - face_start:
        print('The faces file of case %s maybe not correct, because the face number is not consistant'%(caseDir))
    # read point index of each face, e.g. "4(0 1 2 3)"
    faces = {'nNodes': [], 'index': []}
    for line in alldata[face_start:face_end]:
        head, _, rest = line.partition('(')
        # split() (not split(' ')) so repeated spaces do not produce empty tokens
        index = rest.split(')')[0].split()
        faces['nNodes'].append(int(head))
        faces['index'].append(np.array(index, dtype=int))
    return faces
def readOwner(caseDir):
    """Read the ASCII ``constant/polyMesh/owner`` file of an OpenFOAM case.

    Only the multi-line list layout is parsed: a count line, a line holding
    just ``(``, one label per line and a line holding just ``)``. A warning is
    printed when the count does not match the entries read.

    Parameters
    ----------
    caseDir : str
        Path of the OpenFOAM case directory.

    Returns
    -------
    numpy.ndarray of int
        The owner cell label of each listed face (one entry per face).

    Raises
    ------
    FileNotFoundError
        If the owner file does not exist.
    """
    caseDir = os.path.abspath(caseDir)
    file_owner = os.path.join(caseDir, 'constant', 'polyMesh', 'owner')
    with open(file_owner, encoding='utf-8') as f:
        alldata = [line.rstrip() for line in f]
    nFaces, face_start, face_end = 0, 0, 0  # face_end is exclusive
    for i, line in enumerate(alldata):
        if line == '(':
            nFaces = int(alldata[i-1])
            face_start = i + 1
        elif line == ')':
            face_end = i
    if nFaces != face_end - face_start:
        print('The owner file of case %s maybe not correct, because the face number is not consistent'%(caseDir))
    return np.array(alldata[face_start:face_end], dtype=int)
def readNeighbour(caseDir):
    """Read the ASCII ``constant/polyMesh/neighbour`` file of an OpenFOAM case.

    Only the multi-line list layout is parsed: a count line, a line holding
    just ``(``, one label per line and a line holding just ``)``. A warning is
    printed when the count does not match the entries read.

    Parameters
    ----------
    caseDir : str
        Path of the OpenFOAM case directory.

    Returns
    -------
    numpy.ndarray of int
        The neighbour cell label of each listed face (one entry per internal face).

    Raises
    ------
    FileNotFoundError
        If the neighbour file does not exist.
    """
    caseDir = os.path.abspath(caseDir)
    file_neighbour = os.path.join(caseDir, 'constant', 'polyMesh', 'neighbour')
    with open(file_neighbour, encoding='utf-8') as f:
        alldata = [line.rstrip() for line in f]
    nFaces, face_start, face_end = 0, 0, 0  # face_end is exclusive
    for i, line in enumerate(alldata):
        if line == '(':
            nFaces = int(alldata[i-1])
            face_start = i + 1
        elif line == ')':
            face_end = i
    if nFaces != face_end - face_start:
        print('The neighbour file of case %s maybe not correct, because the face number is not consistent'%(caseDir))
    return np.array(alldata[face_start:face_end], dtype=int)
def readBoundary(caseDir, nAllFaces=None):
    """Read the ASCII ``constant/polyMesh/boundary`` file of an OpenFOAM case.

    Parameters
    ----------
    caseDir : str
        Path of the OpenFOAM case directory.
    nAllFaces : int, optional
        Total number of mesh faces. When given, the indices of faces not covered
        by any patch and a per-face patch name are also computed.

    Returns
    -------
    boundaries : dict
        ``{'name': [...], 'nFaces': [...], 'startFace': [...], 'type': [...],
        'index': [...]}`` with one entry per patch; ``index`` is the list of face
        indices ``startFace .. startFace+nFaces-1`` of that patch.
    index_internalFaces : numpy.ndarray or list
        Indices of faces not covered by any patch (empty list if ``nAllFaces`` is None).
    name_faces : list of str
        Patch name of every face, ``'internal'`` for faces not covered by any
        patch (empty list if ``nAllFaces`` is None).

    Raises
    ------
    FileNotFoundError
        If the boundary file does not exist.
    ValueError
        If a patch has no ``type``, ``nFaces`` or ``startFace`` entry.
    """
    caseDir = os.path.abspath(caseDir)
    file_boundary = os.path.join(caseDir, 'constant', 'polyMesh', 'boundary')
    with open(file_boundary, encoding='utf-8') as f:
        alldata = [line.rstrip() for line in f]
    # top-level list: "<count>" line, then "(" ... ")" at column 0
    nBoundaries, start, end = 0, 0, 0
    for i, line in enumerate(alldata):
        if line == '(':
            nBoundaries = int(alldata[i-1])
            start = i + 1
        elif line == ')':
            end = i
    # parse each "name { key value; ... }" patch dictionary by exact key,
    # so that e.g. a line merely containing the text 'type' cannot misalign the lists
    patches = []
    name_prev, entries = None, None
    for line in alldata[start:end]:
        s = line.strip()
        if not s:
            continue
        if s.endswith('{'):  # a patch starts; its name is on this or the previous line
            head = s[:-1].split()
            entries = {}
            patches.append((head[0] if head else name_prev, entries))
        elif s.endswith('}'):  # a patch ends
            entries = None
        elif entries is not None:
            tokens = s.split()
            if len(tokens) >= 2:
                entries.setdefault(tokens[0], tokens[1].split(';')[0])
        else:
            name_prev = s.split()[0]
    if nBoundaries != len(patches):
        print('boundary file parse failure, because boundary number are not consistant: %s'%(caseDir))
    boundaries = {'name': [], 'nFaces': [], 'startFace': [], 'type': [], 'index': []}
    index_allBoundaries = []
    for name, entries in patches:
        missing = [key for key in ('type', 'nFaces', 'startFace') if key not in entries]
        if missing:
            raise ValueError('patch %s in %s has no %s entry' % (name, file_boundary, ', '.join(missing)))
        nFaces = int(entries['nFaces'])
        startFace = int(entries['startFace'])
        index = list(range(startFace, startFace + nFaces))
        boundaries['name'].append(name)
        boundaries['type'].append(entries['type'])
        boundaries['nFaces'].append(nFaces)
        boundaries['startFace'].append(startFace)
        boundaries['index'].append(index)
        index_allBoundaries.extend(index)  # extend, not "+", to stay linear
    print('nBoundaries: %d, '%(nBoundaries), boundaries['name'])
    # faces not covered by any patch are labelled 'internal'
    index_internalFaces = []
    name_faces = []
    if nAllFaces is not None:
        name_faces = ['internal'] * nAllFaces
        is_internal = np.ones(nAllFaces, dtype=bool)
        is_internal[index_allBoundaries] = False
        index_internalFaces = np.flatnonzero(is_internal)
        for name, index in zip(boundaries['name'], boundaries['index']):
            for ind in index:
                name_faces[ind] = name
        print('nInternalFaces: %d'%(len(index_internalFaces)))
    return boundaries, index_internalFaces, name_faces
def getCells(owners, neighbours):
    """Group face indices by cell.

    Parameters
    ----------
    owners : array_like of int
        Owner cell of every face (face ``i`` is owned by ``owners[i]``).
    neighbours : array_like of int
        Neighbour cell of every internal face; may be empty.

    Returns
    -------
    list of list of int
        ``cells[c]`` holds the indices of the faces owned by cell ``c`` followed
        by those for which ``c`` is the neighbour, each group in ascending order.
        The number of cells is the largest label in either input plus one.
    """
    owners = np.asarray(owners, dtype=int)
    neighbours = np.asarray(neighbours, dtype=int)
    # initial=-1 keeps max() defined for an empty neighbour list (single-cell mesh)
    nCells = int(max(owners.max(initial=-1), neighbours.max(initial=-1))) + 1
    cells = [[] for _ in range(nCells)]
    for face, cell in enumerate(owners):
        cells[cell].append(face)
    for face, cell in enumerate(neighbours):
        cells[cell].append(face)
    return cells
def read(caseDir):
    """Read an OpenFOAM polyMesh and build cell-to-cell connectivity.

    Parameters
    ----------
    caseDir : str
        Path of the OpenFOAM case directory.

    Returns
    -------
    dict
        ``{'points', 'cells', 'faces', 'owners', 'neighbours'}`` where ``points``
        is the dict of ``readPoints``, ``faces`` is the dict of ``readFaces``
        extended with the per-face patch ``'name'`` from ``readBoundary``, and
        ``cells`` is ``{'faces': [...], 'neighbour': [...], 'owner': [...]}`` with,
        for every cell:

        * ``faces``: faces it owns, then faces for which it is the neighbour;
        * ``owner``: cells across the faces named 'internal' that it owns;
        * ``neighbour``: owner cells of the faces for which it is the neighbour.

        Faces on patches of type ``empty`` are excluded from both cell lists.
    """
    caseDir = os.path.abspath(caseDir)
    points = readPoints(caseDir)
    faces = readFaces(caseDir)
    owners = readOwner(caseDir)
    neighbours = readNeighbour(caseDir)
    faceIndex_cells = getCells(owners, neighbours)
    boundaries, index_internalFaces, name_faces = readBoundary(caseDir, len(faces['nNodes']))
    faces['name'] = name_faces
    # 1. names of the patches of type 'empty'
    name_emptyPatch = {name for name, type_patch in zip(boundaries['name'], boundaries['type'])
                       if type_patch == 'empty'}
    # 2. neighbour cells of a cell
    nCells = len(faceIndex_cells)
    cells = {'faces': faceIndex_cells,
             'neighbour': [[] for _ in range(nCells)],
             'owner': [[] for _ in range(nCells)]}
    # Only faces with a neighbour entry can contribute, so a single ascending pass
    # over them replaces the former per-cell np.where scans (O(nCells*nFaces)).
    # Per-cell lists keep ascending face order, as before.
    for face in range(len(neighbours)):
        name = name_faces[face]
        if name in name_emptyPatch:
            continue
        own, nei = owners[face], neighbours[face]
        if name == 'internal':
            cells['owner'][own].append(nei)
        cells['neighbour'][nei].append(own)
    mesh = {'points': points, 'cells': cells, 'faces': faces, 'owners': owners, 'neighbours': neighbours}
    return mesh
def _faceNormal2D(x_face, y_face, z_face):
    """Return the (x, y) components of a face's unit normal.

    The normal is the cross product of the edges node0->node1 and node1->node2,
    so its orientation follows the face's node ordering. It is normalised by its
    full 3D length before the z component is dropped.
    """
    norm = np.cross([x_face[1] - x_face[0], y_face[1] - y_face[0], z_face[1] - z_face[0]],
                    [x_face[2] - x_face[1], y_face[2] - y_face[1], z_face[2] - z_face[1]])
    return norm[0:2] / np.sqrt(np.sum(norm**2))


def _markAdjacentCell(ax, x_adj, y_adj, x_cell, y_cell, shift, index_local_cell):
    """Shade an adjacent cell gray and label it ``F_<index_local_cell>``.

    The label sits at the adjacent cell's node mean, moved by ``shift`` (a
    2-vector) times the distance between the two cells' node means.
    """
    ax.fill(x_adj, y_adj, fc='gray', alpha=0.8)
    dist_cells = np.sqrt(np.sum((np.array([x_adj.mean(), y_adj.mean()]) - np.array([x_cell.mean(), y_cell.mean()]))**2))
    ax.text(x_adj.mean() + shift[0]*dist_cells, y_adj.mean() + shift[1]*dist_cells,
            r'$F_{\mathregular{%d}}$' % (index_local_cell), va='center', ha='center', color='w', fontsize=16)


def _drawFaceWithNormal(ax, x_face, y_face, norm, color, index_local_face, **kwargs):
    """Draw a face, an arrow along its normal ``norm`` and its label ``f_<index_local_face>``."""
    lf, = ax.plot(x_face, y_face, color=color, **kwargs)
    xy_cf = np.array([x_face.mean(), y_face.mean()])
    len_arrow = np.sqrt(np.sum(np.array([x_face.max()-x_face.min(), y_face.max()-y_face.min()])**2))/3
    ax.annotate("", xy=xy_cf+norm*len_arrow/2, xytext=xy_cf-norm*len_arrow/2,
                arrowprops=dict(arrowstyle="->", color=lf.get_color()))
    norm_orth = [-norm[1], norm[0]]  # a vector orthogonal to the normal
    ax.text(xy_cf[0]+norm_orth[0]*len_arrow/2, xy_cf[1]+norm_orth[1]*len_arrow/2,
            '$f_{%d}$' % (index_local_face), fontsize=16, ha='center', va='center', color='w')


def plotMeshTopology(ax,caseDir, ind_cell=None,index_intFace=None, meshGenerator='blockMesh',**kwargs):
    """Plot the polyMesh topology of a 2D OpenFOAM case on a matplotlib axis.

    The 2D polygon of each cell is taken from its first face on a patch of type
    ``empty``.

    The plot shows:

    * every cell index;
    * every internal face with its index;
    * every non-empty boundary patch with local:global face indices;
    * one internal face with its owner and neighbour cells;
    * one cell ``C`` with its adjacent cells ``F_i`` and faces ``f_i``, where each
      face is drawn with its normal.

    Parameters
    ----------
    ax : matplotlib Axes
        Axis to draw on.
    caseDir : str
        Path of the OpenFOAM case directory.
    ind_cell : int, optional
        Cell to highlight; defaults to ``nCells // 2``.
    index_intFace : int, optional
        Internal face to highlight; defaults to ``nInternalFaces // 2``.
    meshGenerator : str, optional
        Unused; kept for interface compatibility.
    **kwargs
        Forwarded to several ``ax.plot`` calls.

    Returns
    -------
    tuple
        ``(x, y, z, cells_poly, faces, boundaries, owners, neighbours)``.
    """
    caseDir = os.path.abspath(caseDir)
    points = readPoints(caseDir)
    x, y, z = points['x'], points['y'], points['z']
    faces = readFaces(caseDir)
    owners = readOwner(caseDir)
    neighbours = readNeighbour(caseDir)
    cells = getCells(owners, neighbours)
    boundaries, index_internalFaces, name_faces = readBoundary(caseDir, len(faces['nNodes']))
    faces['name'] = name_faces
    name_emptyPatch = [name_patch for name_patch, type_patch in zip(boundaries['name'], boundaries['type'])
                       if type_patch == 'empty']

    # 1. polygon of every cell (its first face on an empty patch) and the cell index at its node mean
    cells_poly = [0]*len(cells)
    label = 'Cell index'
    for i, faces_of_cell in enumerate(cells):
        for face in faces_of_cell:
            if name_faces[face] in name_emptyPatch:
                cells_poly[i] = faces['index'][face]
                x_cell, y_cell = x[cells_poly[i]], y[cells_poly[i]]
                ax.plot(x_cell.mean(), y_cell.mean(), 'o', mfc='lightskyblue', mec='k', ms=15, label=label, **kwargs)
                ax.text(x_cell.mean(), y_cell.mean(), '%d' % (i), va='center', ha='center')
                label = None
                break

    # 2. all internal faces with their global index
    for i, index_face_internal in enumerate(index_internalFaces):
        label = 'Internal face: %d' % (len(index_internalFaces)) if i == 0 else None
        index_points = faces['index'][index_face_internal]
        x_face, y_face = x[index_points], y[index_points]
        ax.plot(x_face, y_face, 'k', label=label, lw=1)
        ax.text(x_face.mean(), y_face.mean(), '%d' % (index_face_internal), va='center', ha='center',
                color='k', bbox={'color': 'lightgray'})

    # 3. boundary patches, skipping empty (front/back) patches of this 2D case.
    # Colours are indexed modulo the cycle length so no patch is dropped when
    # there are more patches than colours.
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    for k, (name, patchType, patchIndex) in enumerate(zip(boundaries['name'], boundaries['type'], boundaries['index'])):
        if name in name_emptyPatch:
            continue
        lc = colors[k % len(colors)]
        for i, index_face_patch in enumerate(patchIndex):
            label = '%s(%s): %d' % (name, patchType, len(patchIndex)) if i == 0 else None
            index_points = faces['index'][index_face_patch]
            x_tmp, y_tmp = x[index_points], y[index_points]
            ax.plot(x_tmp, y_tmp, color=lc, label=label, **kwargs)
            rot = 90 if x_tmp.min() == x_tmp.max() else 0  # vertical face -> vertical label
            ax.text(x_tmp.mean(), y_tmp.mean(), '%d:%d' % (i, index_face_patch), va='center', ha='center',
                    rotation=rot, color='k', bbox={'color': 'lightgray'}, alpha=0.5)

    # 4. one internal face with its owner and neighbour cells
    if index_intFace is None:
        index_intFace = len(index_internalFaces)//2
    ax.plot(x[faces['index'][index_intFace]], y[faces['index'][index_intFace]], 'r',
            label='The %d$_{th}$ internal face' % (index_intFace), **kwargs)
    ax.fill(x[cells_poly[owners[index_intFace]]], y[cells_poly[owners[index_intFace]]],
            fc='dodgerblue', label='Owner cell of face %d' % (index_intFace))
    ax.fill(x[cells_poly[neighbours[index_intFace]]], y[cells_poly[neighbours[index_intFace]]],
            fc='purple', label='Neighbour cell of face %d' % (index_intFace))

    # 5. one cell C, its adjacent cells F_i and its faces f_i with their normals
    if ind_cell is None:
        ind_cell = len(cells_poly)//2
    x_cell, y_cell = x[cells_poly[ind_cell]], y[cells_poly[ind_cell]]
    ax.fill(x_cell, y_cell, label='The %d$_{th}$ cell' % (ind_cell), fc='limegreen')
    len_diag = np.sqrt(np.sum(np.array([x_cell.max() - x_cell.min(), y_cell.max()-y_cell.min()])**2))
    ax.text(x_cell.mean()-len_diag/8, y_cell.mean()-len_diag/8, '$C$', color='w', fontsize=16,
            fontweight='bold', ha='center', va='center')
    index_local_cells, index_local_faces = 1, 1
    # 5.1 faces owned by the cell (red if internal, cyan if on a boundary patch)
    for face in np.where(owners == ind_cell)[0]:
        if faces['name'][face] in name_emptyPatch:
            continue
        index_points = faces['index'][face]
        x_face, y_face, z_face = x[index_points], y[index_points], z[index_points]
        norm = _faceNormal2D(x_face, y_face, z_face)
        lc = 'r'
        if faces['name'][face] == 'internal':
            x_adj, y_adj = x[cells_poly[neighbours[face]]], y[cells_poly[neighbours[face]]]
            _markAdjacentCell(ax, x_adj, y_adj, x_cell, y_cell, -norm*0.2, index_local_cells)
            index_local_cells += 1
        if faces['name'][face] in boundaries['name']:  # boundary face
            lc = 'cyan'
        _drawFaceWithNormal(ax, x_face, y_face, norm, lc, index_local_faces, **kwargs)
        index_local_faces += 1
    # 5.2 faces for which the cell is the neighbour (blue)
    for face in np.where(neighbours == ind_cell)[0]:
        if faces['name'][face] in name_emptyPatch:
            continue
        index_points = faces['index'][face]
        x_face, y_face, z_face = x[index_points], y[index_points], z[index_points]
        norm = _faceNormal2D(x_face, y_face, z_face)
        x_adj, y_adj = x[cells_poly[owners[face]]], y[cells_poly[owners[face]]]
        _markAdjacentCell(ax, x_adj, y_adj, x_cell, y_cell, norm*0.2, index_local_cells)
        index_local_cells += 1
        _drawFaceWithNormal(ax, x_face, y_face, norm, 'b', index_local_faces, **kwargs)
        index_local_faces += 1

    return x, y, z, cells_poly, faces, boundaries, owners, neighbours