from PyFaceGen.Mesh.Mesh import Mesh
from PyFaceGen.Mesh.Utils import get_normals
import numpy as np

# Wavefront loader
def wobj_2_mesh(path):
    ''' Load a Wavefront .obj file into a Mesh.

    Reads ``v`` (position), ``vt`` (UV), ``vn`` (normal) and ``f`` (face)
    records; every other record type (comments, groups, materials, ...) is
    ignored. Face corners may be written as ``v``, ``v/vt``, ``v//vn`` or
    ``v/vt/vn`` using 1-based or negative (relative) indices. Polygons with
    more than three corners are fan-triangulated.

    For simplicity the returned mesh always carries UVs and normals, even as
    placeholders:
      - if the faces do not reference UVs, every vertex gets a (0, 0) UV and
        the UV index of a corner equals its vertex index;
      - if the faces do not reference normals, normals are computed with
        ``get_normals`` and the normal index of a corner equals its vertex
        index.

    Parameters
    ----------
    path : str
        Path of the .obj file to read.

    Returns
    -------
    Mesh
        ``Mesh(vertices, uvs, normals, faces)`` where ``faces`` is an
        ``(n_triangles, 3, 3)`` int64 array whose last axis holds the 0-based
        [vertex, uv, normal] indices of each corner.

    Raises
    ------
    ValueError
        If the file contains no faces, a face has fewer than three corners,
        faces mix formats (some with UV/normal indices, some without), or an
        index is zero or out of range.
    '''
    vertices = []
    uvs = []
    normals = []
    # Each triangle is a 3-tuple of (vertex, uv, normal) corners; uv / normal
    # are None when the face does not specify them.
    triangles = []

    # OBJ numeric data is ASCII; errors='replace' keeps non-ASCII text in
    # comments or object names from aborting the load.
    with open(path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            # split() with no argument tolerates repeated spaces, tabs and the
            # trailing newline, unlike split(' ').
            elements = line.split()
            if not elements:
                continue
            tag = elements[0]

            if tag == 'v':
                # Vertex position (optional w / colour components are ignored)
                x, y, z = (float(e) for e in elements[1:4])
                vertices.append(np.array([x, y, z]))

            elif tag == 'vt':
                # UV coordinate; the OBJ spec makes the v component optional
                u = float(elements[1])
                v = float(elements[2]) if len(elements) > 2 else 0.0
                uvs.append(np.array([u, v]))

            elif tag == 'vn':
                # Normal vector
                x, y, z = (float(e) for e in elements[1:4])
                normals.append(np.array([x, y, z]))

            elif tag == 'f':
                # Negative indices are relative to what has been read so far,
                # so they are resolved against the current list lengths.
                corners = [_parse_face_corner(token, len(vertices), len(uvs), len(normals))
                           for token in elements[1:]]
                if len(corners) < 3:
                    raise ValueError('%s: face with fewer than 3 corners: %r'
                                     % (path, line.strip()))
                # Fan-triangulate quads / polygons rather than keeping only
                # their first three corners.
                for i in range(1, len(corners) - 1):
                    triangles.append((corners[0], corners[i], corners[i + 1]))

    if not triangles:
        raise ValueError('%s contains no faces' % path)

    # Every face must agree on which optional indices it carries, otherwise
    # the index columns cannot be stacked into one array.
    uv_given = {corner[1] is not None for tri in triangles for corner in tri}
    normal_given = {corner[2] is not None for tri in triangles for corner in tri}
    if len(uv_given) > 1 or len(normal_given) > 1:
        raise ValueError('%s mixes face formats with and without UV/normal indices' % path)
    faces_have_uvs = uv_given == {True}
    faces_have_normals = normal_given == {True}

    vertices = np.array(vertices)
    vert_idx = _face_column(triangles, 0, len(vertices), 'vertex', path)

    if faces_have_uvs:
        uv_idx = _face_column(triangles, 1, len(uvs), 'uv', path)
        uvs = np.array(uvs)
        partial_columns = [vert_idx, uv_idx]
    else:
        # Placeholder (0, 0) UV per vertex, indexed exactly like the positions
        uvs = np.zeros((vertices.shape[0], 2))
        uv_idx = vert_idx
        partial_columns = [vert_idx]

    if faces_have_normals:
        norm_idx = _face_column(triangles, 2, len(normals), 'normal', path)
        normals = np.array(normals)
    else:
        # Compute normals using 1-hop neighbours; get_normals receives the
        # position (and, when present, UV) index columns of each face.
        normals = get_normals(vertices, np.stack(partial_columns, axis=2))
        # Vertex position faces with normals in 1-1 correspondence
        norm_idx = vert_idx

    faces = np.stack([vert_idx, uv_idx, norm_idx], axis=2)

    mesh = Mesh(vertices, uvs, normals, faces)
    return mesh


def _obj_index(token, count):
    ''' Convert one OBJ index token to a 0-based index.

    OBJ indices are 1-based; negative indices count back from the end of the
    ``count`` elements read so far (-1 is the most recent one).
    '''
    index = int(token)
    if index > 0:
        return index - 1
    if index < 0:
        return count + index
    raise ValueError('OBJ indices are 1-based; found index 0')


def _parse_face_corner(token, n_vertices, n_uvs, n_normals):
    ''' Parse one face corner written as v, v/vt, v//vn or v/vt/vn.

    Returns a (vertex, uv, normal) tuple of 0-based indices; uv and normal
    are None when the corner omits them.
    '''
    parts = token.split('/')
    vertex = _obj_index(parts[0], n_vertices)
    uv = _obj_index(parts[1], n_uvs) if len(parts) > 1 and parts[1] else None
    normal = _obj_index(parts[2], n_normals) if len(parts) > 2 and parts[2] else None
    return vertex, uv, normal


def _face_column(triangles, k, count, kind, path):
    ''' Gather the k-th index of every triangle corner into an (n, 3) int64
    array, checking each index addresses one of the ``count`` parsed items.
    '''
    column = np.array([[corner[k] for corner in tri] for tri in triangles], dtype=np.int64)
    if column.min() < 0 or column.max() >= count:
        raise ValueError('%s: face references a %s index outside the %d defined'
                         % (path, kind, count))
    return column
                

def mesh_2_wobj(mesh, path, precision=4):
    ''' Write a mesh to a Wavefront .obj file.

    Parameters
    ----------
    mesh : Mesh-like
        Object exposing ``vertices``, ``UVs``, ``normals`` and ``faces``.
        ``UVs`` is only read when the faces carry UV indices, ``normals``
        only when they carry normal indices.
    path : str
        Destination file; overwritten if it already exists.
    precision : int, optional
        Decimal places written for vertex, UV and normal coordinates
        (default 4).

    ``faces`` may be an ``(n_faces, k)`` array of 0-based vertex indices, or
    an ``(n_faces, k, m)`` array whose last axis holds the vertex index
    followed optionally by the UV index (m >= 2) and the normal index
    (m == 3). Indices are written 1-based, as the OBJ format requires.

    Raises
    ------
    ValueError
        If ``faces`` does not have one of the shapes described above. The
        check happens before the file is opened, so nothing is written.
    '''
    # Get the properties of the mesh
    vertices = mesh.vertices
    uvs = mesh.UVs
    normals = mesh.normals
    faces = np.asarray(mesh.faces)

    if faces.ndim == 2:
        # Bare vertex indices: add a trailing property axis so every layout
        # is written by the same loop below.
        faces = faces[:, :, np.newaxis]
    if faces.ndim != 3 or faces.shape[-1] not in (1, 2, 3):
        raise ValueError('faces must have shape (n, k) or (n, k, m) with m in 1..3, got %s'
                         % (faces.shape,))

    # Last axis is [vertex, uv, normal]; the trailing entries are optional.
    n_props = faces.shape[-1]
    has_uv = n_props >= 2
    has_norm = n_props == 3

    num = '%%.%df' % precision
    v_fmt = 'v %s %s %s\n' % (num, num, num)
    vt_fmt = 'vt %s %s\n' % (num, num)
    vn_fmt = 'vn %s %s %s\n' % (num, num, num)

    with open(path, 'w') as f:
        # Header information
        f.write('# Generated by script \n')
        f.write('# %s vertices, %s faces \n' % (len(vertices), faces.shape[0]))

        for vertex in vertices:
            f.write(v_fmt % (vertex[0], vertex[1], vertex[2]))

        if has_uv:
            for uv in uvs:
                f.write(vt_fmt % (uv[0], uv[1]))

        if has_norm:
            for normal in normals:
                f.write(vn_fmt % (normal[0], normal[1], normal[2]))

        for face in faces:
            # Each corner becomes "v", "v/vt" or "v/vt/vn" with 1-based indices;
            # faces with any number of corners are supported.
            corners = ('/'.join('%i' % (idx + 1) for idx in corner) for corner in face)
            f.write('f %s\n' % ' '.join(corners))
