"""
transform.py

language: Python
version: 3.7
author: C. Lockhart <chris@lockhartlab.org>
"""

from numba import njit
import numpy as np


def fit(a, b, backend='python'):
    """
    Superimpose every structure of `b` onto the paired structure of `a`.

    `a` and `b` must be paired: structure *i* of `b` is fit onto structure *i*
    of `a`, and atom *j* of one corresponds to atom *j* of the other. Each
    structure of `b` is centered, rotated by the Kabsch rotation (SVD of the
    cross-covariance matrix with a reflection correction), and then translated
    to the centroid of the matching structure of `a`.

    Parameters
    ----------
    a : simlib.Trajectory
        Reference trajectory. Only ``n_structures``, ``n_atoms`` and ``xyz``
        are read; ``xyz`` is expected to have shape
        (n_structures, n_atoms, 3).
    b : simlib.Trajectory
        Mobile trajectory with the same layout as `a`. Not modified.
    backend : str
        Currently unused; accepted for API compatibility.

    Returns
    -------
    numpy.ndarray
        Fitted coordinates of `b`, same shape as ``b.xyz``.

    Raises
    ------
    AttributeError
        If `a` and `b` differ in number of structures or number of atoms.
    """

    # Number of structures
    n_structures = a.n_structures
    if n_structures != b.n_structures:
        raise AttributeError('a and b must have the same number of structures')

    # Number of atoms
    n_atoms = a.n_atoms
    if n_atoms != b.n_atoms:
        raise AttributeError('a and b must have the same number of atoms')

    # Get coordinates
    a_xyz = a.xyz
    b_xyz = b.xyz

    # Per-structure centroids, shape (n_structures, 1, n_dim); keepdims lets them
    # broadcast across the atom axis without building a tiled copy
    a_xyz_center = a_xyz.mean(axis=1, keepdims=True)
    b_xyz_center = b_xyz.mean(axis=1, keepdims=True)

    # Move both sets of structures to the origin (new arrays; inputs untouched)
    a_xyz = a_xyz - a_xyz_center
    b_xyz = b_xyz - b_xyz_center

    # Cross-covariance H = B^T A for every structure, shape (n_structures, n_dim, n_dim)
    covariance_matrix = np.matmul(np.transpose(b_xyz, axes=(0, 2, 1)), a_xyz)

    # Optimal rotation per structure, applied in row-vector form (b @ R)
    rotation_matrix = _get_rotation_matrix(covariance_matrix)

    # Rotate b, then move it onto a's centroids
    return np.matmul(b_xyz, rotation_matrix) + a_xyz_center


def to_origin(a):
    """
    Translate every structure of `a` so that its centroid sits at the origin.

    The centroid is the unweighted mean of the atom coordinates within each
    structure. `a` is modified: ``a.xyz`` is replaced by the centered
    coordinates.

    Parameters
    ----------
    a : simlib.Trajectory
        Any object with an ``xyz`` array attribute of shape
        (n_structures, n_atoms, n_dim) (presumably a ``simlib.Trajectory``).

    Returns
    -------
    same type as `a`
        The same object `a`, with ``xyz`` centered.
    """

    # Compute each structure's centroid. keepdims gives shape
    # (n_structures, 1, n_dim) so it broadcasts over the atom axis; a plain
    # (n_structures, n_dim) centroid would be aligned against the atom axis
    # instead (a broadcast error, or silently wrong when n_structures == n_atoms).
    # TODO put this in geometry.center ?
    center = a.xyz.mean(axis=1, keepdims=True)

    # Transform (rebinds a.xyz to a freshly computed array)
    a.xyz = a.xyz - center

    # Return
    return a


@njit
def _get_rotation_matrix(covariance_matrix):
    """
    Compute the optimal (Kabsch) rotation for each cross-covariance matrix.

    For every ``H = covariance_matrix[i]`` with SVD ``H = U S Vh``, the rotation
    is ``R = U D Vh`` where ``D = diag(1, ..., 1, d)`` and
    ``d = sign(det(U Vh))``. The ``d`` factor flips the last singular direction
    when ``U Vh`` would be a reflection, so ``R`` is a proper rotation. ``R`` is
    meant to be applied to row-vector coordinates as ``xyz @ R``.

    Parameters
    ----------
    covariance_matrix : numpy.ndarray
        Floating-point array of shape (n_structures, n_dim, n_dim).

    Returns
    -------
    numpy.ndarray
        Rotation matrices, same shape and dtype as `covariance_matrix`.
    """
    n_structures = covariance_matrix.shape[0]
    rotation_matrix = np.empty_like(covariance_matrix)
    for i in range(n_structures):
        u, _, vh = np.linalg.svd(covariance_matrix[i])
        # +1 for a proper rotation, -1 if U @ Vh is a reflection
        d = np.sign(np.linalg.det(np.dot(u, vh)))
        # U @ diag(1, ..., 1, d) @ Vh equals U @ Vh with the last row of Vh
        # scaled by d. Doing it this way keeps every operand in the input
        # dtype: numba's np.dot requires matching dtypes, so multiplying a
        # float32 U by a float64 diagonal matrix would fail to compile.
        vh[-1, :] = vh[-1, :] * d
        rotation_matrix[i] = np.dot(u, vh)
    return rotation_matrix
