#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from __future__ import print_function
from io import open
import numpy as np
from intervul.readpos import VulcanPosMesh
from intervul.general import Submesh
import os
import argparse
import re


# Command-line interface of the script. The help texts are user-facing runtime
# strings and are kept exactly as written (Spanish).
#
# Each entry pairs the option flags with the keyword arguments handed to
# ``parser.add_argument``; entries are registered in the order listed, which
# is also the order in which they show up in ``--help``.
_CLI_ARGUMENTS = (
    # Integer set number, defaults to 1.
    (('-s', '--set'), dict(
        help='Indica de que set quieres obtener la fuerza',
        default=1,
        type=int,
    )),
    # Force direction given as free text, defaults to 'x'.
    (('-d', '--direction'), dict(
        help='Define en que dirección se quiere la fuerza.'
             '"x", "y", "z" o "[1.0, 1.0, 1.0]"',
        default='x',
    )),
    # Optional file, opened for reading by argparse itself.
    (('-l', '--listofnodes'), dict(
        help='Si set es -1 toma esta lista de nodos para calcular las fuerzas',
        type=argparse.FileType('r'),
    )),
    # Name of the output file.
    (('-o', '--fileOutput'), dict(
        help='Indica el nombre del archivo de salida (default: forces.txt)',
        default="forces.txt",
    )),
    # Boolean switch, False unless given.
    (('--printDisp',), dict(
        help="Si agrega esta opcion imprime el desplazamiento",
        action='store_true',
    )),
    # The only positional argument: the input data file.
    (('fileIn',), dict(help="Archivo de dato")),
)

parser = argparse.ArgumentParser(
    description='Extrae las fuerzas de un set en particular')
for _flags, _settings in _CLI_ARGUMENTS:
    parser.add_argument(*_flags, **_settings)

# Parse the command line and unpack the options used by the rest of the script.
args = parser.parse_args()

filename = args.fileIn
fileout = args.fileOutput
# The set number is 1-based on the command line; it is stored 0-based here.
setForce = args.set - 1
direction = args.direction
nodesfile = args.listofnodes
# A set number below 1 selects the explicit list of nodes instead of a set.
nodeList = args.set < 1
if nodeList:
    if nodesfile is None:
        # Without this check the loop below would fail on ``None`` with an
        # unhelpful TypeError.
        parser.error("a set number below 1 requires a list of nodes "
                     "(-l/--listofnodes)")
    # The file holds one 1-based node number per line; blank lines are skipped.
    try:
        nodes = [int(line) for line in nodesfile if line.strip()]
    except ValueError as err:
        parser.error("invalid node number in " + nodesfile.name + ": " +
                     str(err))
    finally:
        # argparse opened the file for us; release it once it has been read.
        # Its ``name`` attribute stays available for the messages printed later.
        nodesfile.close()
    if not nodes:
        parser.error("the list of nodes " + nodesfile.name + " is empty")
    # An explicit integer dtype keeps the array usable as an index; shifting
    # by one turns the 1-based node numbers into 0-based indices.
    nodes = np.array(nodes, dtype=int) - 1
    if nodes.min() < 0:
        # A negative index would silently wrap around to the last nodes.
        parser.error("node numbers in " + nodesfile.name +
                     " must be 1 or greater")


def _parse_direction_vector(spec):
    """Return the unit vector described by a bracketed list of numbers.

    ``spec`` is text such as ``"[1.0, 1.0, 1.0]"``; the components may be
    separated by commas and/or blanks. The result is a list of plain Python
    floats so that ``str()`` of it prints the same on every NumPy version
    (NumPy 2 prints its own scalars as ``np.float64(...)``).

    Raises ValueError if the closing bracket is missing, a component is not a
    number, or the vector cannot be normalised (zero or non-finite length).
    """
    text = spec.strip()
    if not text.endswith(']'):
        raise ValueError("missing closing ']'")
    # Strip the brackets and surrounding blanks first so that inputs such as
    # "[ 1, 0, 0 ]" or "[1 , 0 , 0]" do not produce empty tokens.
    tokens = re.split(r'\s*,\s*|\s+', text[1:-1].strip())
    components = [float(token) for token in tokens]
    norm = float(np.linalg.norm(components))
    if norm == 0.0 or not np.isfinite(norm):
        raise ValueError("the vector must have a finite, non-zero length")
    return [component / norm for component in components]


if direction.startswith('['):
    # Split the list and normalise the vector.
    try:
        direction = _parse_direction_vector(direction)
    except ValueError as err:
        parser.error("invalid direction " + repr(args.direction) + ": " +
                     str(err))
elif direction not in ('x', 'y', 'z'):
    # Reject unknown directions here instead of failing later, in the middle
    # of the extraction.
    parser.error("invalid direction " + repr(direction) +
                 ': expected "x", "y", "z" or "[1.0, 1.0, 1.0]"')

# Open the results file and, unless an explicit list of nodes was given,
# select the requested set from its mesh.
data = VulcanPosMesh(filename, VulcanPosMesh.MECHANICAL)
mesh = data.mesh
if not nodeList:
    submesh = Submesh(setNum=setForce)
    submesh.getSet(mesh=mesh)

# The direction does not change between steps, so resolve it once: either a
# column index (for 'x', 'y', 'z') or a vector to project onto.
_AXIS_NAMES = ('x', 'y', 'z')
if direction in _AXIS_NAMES:
    axis = _AXIS_NAMES.index(direction)
    directionVector = None
    directionLabel = str(direction)
else:
    axis = None
    # Fails here, before the output file is touched, if the direction is
    # neither an axis name nor a sequence of numbers.
    directionVector = np.asarray(direction, dtype=float)
    # Plain floats print the same on every NumPy version.
    directionLabel = str([float(c) for c in directionVector])

print(fileout)
firstPass = True
with open(fileout, 'w', encoding='utf-8') as f:
    for istep, result in data:
        if firstPass:
            # The header needs the title carried by the first result, so it
            # is written from inside the loop.
            title = result['titleResult']
            # NOTE(review): the title is presumably bytes read from the
            # results file; only decode when it actually is.
            if isinstance(title, bytes):
                title = title.decode('utf-8')
            f.write("# Reaction force of case\n")
            f.write("# " + title.strip() + "\n")
            f.write("# filename: " + filename + "\n")
            if nodeList:
                f.write("# file of nodes: " + nodesfile.name + "\n")
            else:
                f.write("# set: " + str(setForce + 1) + "\n")
            f.write("# direction: " + directionLabel + "\n")
            # Name only the columns that are really written below.
            if args.printDisp:
                f.write("# Time Displacement Force\n")
            else:
                f.write("# Time Force\n")
            if nodeList:
                print("Extracting forces for filename " + filename + " in list "
                      "of nodes " + nodesfile.name)
            else:
                print("Extracting forces for filename " + filename + " in set " +
                      str(setForce + 1))
            firstPass = False

        # Per-node values restricted to the selected nodes; both arrays are
        # indexed as (node, component).
        if nodeList:
            reactionsPerNode = result['reaction'][nodes, :]
            displacementPerNode = result['displacement'][nodes, :]
        else:
            subresult = result.updateBySplitMesh(
                mesh=submesh, results=['reaction', 'displacement'])
            reactionsPerNode = subresult['reaction']
            displacementPerNode = subresult['displacement']

        # Total reaction along the direction, summed over the selected nodes.
        if axis is not None:
            reaction = reactionsPerNode[:, axis].sum()
        else:
            reaction = np.einsum("ij,j -> ", reactionsPerNode, directionVector)

        if args.printDisp:
            # Displacement along the direction, averaged over the nodes.
            if axis is not None:
                displacementSum = displacementPerNode[:, axis].sum()
            else:
                displacementSum = np.einsum(
                    "ij,j -> ", displacementPerNode, directionVector)
            displacement = displacementSum / float(displacementPerNode.shape[0])
            text = "{:8.3f} {:14.6E} {:14.6E}".format(
                result['TimeValue'], displacement, reaction)
        else:
            text = "{:8.3f} {:14.6E}".format(result['TimeValue'], reaction)
        print(text)
        f.write(text + "\n")
