from .util import *
import numpy as np
import cv2

# Reference:
#     1. http://stackoverflow.com/questions/17087446/how-to-calculate-perspective-transform-for-opencv-from-rotation-angles
#     2. http://jepsonsblog.blogspot.tw/2012/11/rotation-in-3d-using-opencvs.html
#     3. https://en.wikipedia.org/wiki/3D_projection#Perspective_projection

class ImageTransformer(object):
    """Rotate/translate an image in 3-D and re-project it with a perspective warp.

    The image is treated as a plane centred on the origin at z = 0. That plane
    is rotated about the x/y/z axes, translated, and projected back onto the
    image plane with a pinhole camera (see the references at the top of this
    file).

    Attributes:
        image_path: Path the image was loaded from.
        image: Source image as returned by ``load_image``.
        warped: Result of the most recent ``skew`` call (initially ``image``).
        height, width: Pixel dimensions of ``image``.
        num_channels: Size of the third axis of ``image``, or 1 when the
            array is 2-D (single channel).
        focal: Focal length (pixels) used by ``get_M``. It is recomputed by
            every ``skew`` call.
    """

    def __init__(self, image_path, height=None, width=None):
        """Load the image at *image_path*.

        ``height`` and ``width`` are forwarded to ``load_image``
        (presumably a resize target -- verify against ``util.load_image``).
        """
        self.image_path = image_path
        self.image = load_image(image_path, height, width)
        self.warped = self.image

        self.height = self.image.shape[0]
        self.width = self.image.shape[1]
        # A 2-D (single-channel) array has no channel axis; indexing
        # shape[2] unconditionally would raise IndexError.
        self.num_channels = self.image.shape[2] if len(self.image.shape) > 2 else 1
        # Same value skew() computes when sin(gamma) == 0, so get_M() can be
        # called before the first skew() without an AttributeError.
        self.focal = np.sqrt(self.height**2 + self.width**2)

    def save(self, img_path):
        """Write the most recent warped image to *img_path*.

        Returns:
            The flag returned by ``cv2.imwrite`` (False if the write failed).
        """
        return cv2.imwrite(img_path, self.warped)

    def skew(self, bg=None, theta=0, phi=0, gamma=0, dx=0, dy=0, dz=0):
        """Rotate the image in 3-D, then store and return the warped result.

        Args:
            bg: Border fill value passed to ``cv2.warpPerspective`` as
                ``borderValue``. Defaults to ``(0, 0, 0)``.
            theta, phi, gamma: Rotation about the x, y and z axes. They are
                converted to radians by ``get_rad`` (presumably given in
                degrees -- confirm in ``util``).
            dx, dy: Translation along x/y applied after rotation. It equals a
                pixel offset when nothing is rotated.
            dz: Ignored. It is always replaced by the computed focal length.
                Kept only for interface compatibility.

        Returns:
            The warped image, with the same width and height as the source.
            The result is also stored in ``self.warped``.
        """
        rtheta, rphi, rgamma = get_rad(theta, phi, gamma)

        # Focal-length heuristic carried over from the reference
        # implementation; it depends only on the z-axis angle.
        d = np.sqrt(self.height**2 + self.width**2)
        sin_gamma = np.sin(rgamma)
        self.focal = d / (2 * sin_gamma if sin_gamma != 0 else 1)
        # Placing the image plane exactly one focal length from the camera
        # makes the projection 1:1 when there is no rotation.
        dz = self.focal
        mat = self.get_M(rtheta, rphi, rgamma, dx, dy, dz)
        if bg is None:
            bg = (0, 0, 0)

        # warpPerspective leaves its source untouched, so no defensive copy.
        self.warped = cv2.warpPerspective(self.image, mat, (self.width, self.height), borderValue=bg)
        return self.warped

    def get_M(self, theta, phi, gamma, dx, dy, dz):
        """Build the 3x3 homography for a 3-D rotation followed by a translation.

        Args:
            theta, phi, gamma: Rotation about the x, y and z axes, in radians.
            dx, dy, dz: Translation applied after the rotation.

        Returns:
            A 3x3 float ``ndarray`` that maps homogeneous source pixel
            coordinates to destination pixel coordinates. It uses
            ``self.width``, ``self.height`` and ``self.focal``.
        """
        w = self.width
        h = self.height
        f = self.focal
        cx, cy = w / 2.0, h / 2.0

        # 2-D -> 3-D: centre the image on the origin and put it in the z = 0
        # plane (third row all zeros). A non-zero z here would move the plane
        # off the rotation centre. Even a zero rotation would then shrink
        # the image by f / (f + 1), and x/y rotations would shift its centre.
        A1 = np.array([ [1, 0, -cx],
                        [0, 1, -cy],
                        [0, 0, 0],
                        [0, 0, 1]])

        # All three matrices are transposes of the textbook right-handed
        # rotation matrices, i.e. each rotates by the negated angle. This
        # fixes the sign convention of theta/phi/gamma.
        RX = np.array([ [1, 0, 0, 0],
                        [0, np.cos(theta), np.sin(theta), 0],
                        [0, -np.sin(theta), np.cos(theta), 0],
                        [0, 0, 0, 1]])

        RY = np.array([ [np.cos(phi), 0, -np.sin(phi), 0],
                        [0, 1, 0, 0],
                        [np.sin(phi), 0, np.cos(phi), 0],
                        [0, 0, 0, 1]])

        RZ = np.array([ [np.cos(gamma), np.sin(gamma), 0, 0],
                        [-np.sin(gamma), np.cos(gamma), 0, 0],
                        [0, 0, 1, 0],
                        [0, 0, 0, 1]])

        # RZ is applied first, then RY, then RX.
        R = np.dot(np.dot(RX, RY), RZ)

        T = np.array([  [1, 0, 0, dx],
                        [0, 1, 0, dy],
                        [0, 0, 1, dz],
                        [0, 0, 0, 1]])

        # 3-D -> 2-D: pinhole projection with focal length f and the
        # principal point at the image centre.
        A2 = np.array([ [f, 0, cx, 0],
                        [0, f, cy, 0],
                        [0, 0, 1, 0]])

        return np.dot(A2, np.dot(T, np.dot(R, A1)))
