# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/core.ipynb (unless otherwise specified).

# Public names exported by `from <this module> import *`.
# NOTE: convert_ellipse_keypoints / convert_LAFs used by extract_patches are not
# defined in this file (presumably provided by the `.laf` star-import).
__all__ = ['extract_patches', 'build_image_pyramid', 'extract_patches_Ms', 'convert_cv2_keypoints',
           'convert_cv2_plus_A_keypoints', 'convert_xyA']

# Cell
from typing import List, Union, Tuple
import numpy as np
import math
from math import sqrt
import cv2
from .laf import *

def extract_patches(kpts: Union[List, np.array],
                    img: np.array,
                    PS: int = 32,
                    mag_factor: float = 10.0,
                    input_format: str = 'cv2') -> List[np.array]:
    """
    Extracts patches given the keypoints in one of the following formats:
     - cv2: list of cv2 keypoints
     - cv2+A: tuple of (list of cv2 keypoints, Nx2x2 np array)
     - ellipse (alias 'xyabc'): Nx5 array, single row is [x y a b c]
     - xyA: Nx6 array, single row is [x y a11 a12 a21 a22]
     - LAF: Nx2x3, 1xNx2x3 or Nx1x2x3 array, single row is [a11 a12 x; a21 a22 y]

    The array formats (ellipse, xyA, LAF) are passed through np.asarray, so
    nested Python lists are accepted as well as numpy arrays.

    mag_factor is a scale coefficient. Use 10 for extracting OpenCV SIFT patches,
    1.0 for OpenCV ORB patches, etc.
    PS is the output patch size in pixels.

    Returns a list with one patch per keypoint. Each patch is an np.array of
    shape [PS, PS] for a single-channel img, or [PS, PS, ch] for an img with
    ch channels (the layout produced by cv2.warpAffine).

    Raises ValueError for an unknown input_format or for a keypoint array
    whose shape does not match the requested format.
    """
    if input_format == 'cv2':
        Ms, pyr_idxs = convert_cv2_keypoints(kpts, PS, mag_factor)
    elif input_format == 'cv2+A':
        Ms, pyr_idxs = convert_cv2_plus_A_keypoints(kpts[0], kpts[1], PS, mag_factor)
    elif input_format in ('ellipse', 'xyabc'):
        kpts = np.asarray(kpts)
        # Explicit check instead of assert: asserts vanish under `python -O`.
        if kpts.ndim != 2 or kpts.shape[1] != 5:
            raise ValueError('Expected Nx5 array for ellipse keypoints, got shape', kpts.shape)
        Ms, pyr_idxs = convert_ellipse_keypoints(kpts, PS, mag_factor)
    elif input_format == 'xyA':
        kpts = np.asarray(kpts)
        if kpts.ndim != 2 or kpts.shape[1] != 6:
            raise ValueError('Expected Nx6 array for xyA keypoints, got shape', kpts.shape)
        Ms, pyr_idxs = convert_xyA(kpts, PS, mag_factor)
    elif input_format == 'LAF':
        kpts = np.asarray(kpts)
        if kpts.ndim == 4:
            # Accept a singleton batch axis either in front (1xNx2x3) or after N (Nx1x2x3).
            if kpts.shape[0] == 1:
                kpts = kpts.squeeze(0)
            elif kpts.shape[1] == 1:
                kpts = kpts.squeeze(1)
            else:
                raise ValueError('Bad shape for laf', kpts.shape)
        if kpts.ndim != 3 or kpts.shape[1:] != (2, 3):
            raise ValueError('Bad shape for laf', kpts.shape)
        Ms, pyr_idxs = convert_LAFs(kpts, PS, mag_factor)
    else:
        raise ValueError('Unknown input format', input_format)
    return extract_patches_Ms(Ms, img, pyr_idxs, PS)


def build_image_pyramid(img: np.array, min_size: int) -> List[np.array]:
    """
    Builds an image pyramid with cv2.pyrDown, down to roughly min_size pixels.

    Level 0 is img itself (not copied); every further level is cv2.pyrDown of
    the previous one. Downsampling continues while the smaller spatial side
    (min of img.shape[:2]) is greater than min_size, so normally the last level
    is the first one whose smaller side is <= min_size.

    pyrDown's output side is (side + 1) // 2, so a 1-pixel side never shrinks.
    The loop therefore also stops as soon as pyrDown leaves the spatial shape
    unchanged; without this, min_size < 1 would loop forever.

    Returns the list of levels, finest first.
    """
    img_pyr = [img]
    cur_img = img
    while min(cur_img.shape[:2]) > min_size:
        next_img = cv2.pyrDown(cur_img)
        if next_img.shape[:2] == cur_img.shape[:2]:
            # No progress possible (image is already 1x1): stop instead of spinning.
            break
        img_pyr.append(next_img)
        cur_img = next_img
    return img_pyr

def extract_patches_Ms(Ms: List[np.array], img: np.array,
                       pyr_idxs: Union[List[int], None] = None, PS: int = 32):
    """
    Builds an image pyramid and rectifies one PS x PS patch per transformation
    matrix, sampling each patch from its requested pyramid level. Sampling
    large regions from a downsampled level removes high-frequency artifacts.
    Border mode is set to "replicate", so boundary patches don't get black
    borders.
    Upgraded version of
    https://github.com/vbalnt/tfeat/blob/master/tfeat_utils.py

    Ms: list of 2x3 matrices. With WARP_INVERSE_MAP, M maps patch pixel
        coordinates to image coordinates at pyramid level pyr_idxs[i].
    pyr_idxs: pyramid level for each matrix. If None (default), every matrix
        is applied to level 0, i.e. the original image.
    PS: output patch size in pixels. The pyramid is built down to PS // 2.

    If a requested level is deeper than the pyramid built for img, the deepest
    available level is used instead. M is then scaled up by 2**(requested - used),
    so the sampled region stays the same. This assumes the level-k matrix is the
    level-0 matrix divided by 2**k, as convert_cv2_keypoints / convert_xyA produce.

    Returns a list of patches, each [PS, PS] or [PS, PS, ch] (same channel
    layout as img).
    Raises ValueError if len(Ms) != len(pyr_idxs).
    """
    if pyr_idxs is None:
        pyr_idxs = [0] * len(Ms)
    if len(Ms) != len(pyr_idxs):
        raise ValueError('Ms and pyr_idxs must have the same length', len(Ms), len(pyr_idxs))
    if len(Ms) == 0:
        # Nothing to extract: skip building the pyramid.
        return []
    img_pyr = build_image_pyramid(img, PS // 2)
    max_pyr_idx = len(img_pyr) - 1
    flags = cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR | cv2.WARP_FILL_OUTLIERS
    patches = []
    for M, pyr_idx in zip(Ms, pyr_idxs):
        level = min(max_pyr_idx, pyr_idx)
        if level < pyr_idx:
            # M targets a coarser level than exists; coordinates at a level that is
            # k steps finer are 2**k times larger, so rescale the whole matrix.
            M = np.asarray(M, dtype=np.float64) * float(2 ** (pyr_idx - level))
        patch = cv2.warpAffine(img_pyr[level], M, (PS, PS),
                               flags=flags,
                               borderMode=cv2.BORDER_REPLICATE)
        patches.append(patch)
    return patches

def convert_cv2_keypoints(kps: List, PS: int, mag_factor: float):
    """
    Turns OpenCV keypoints into affine sampling matrices plus pyramid levels.

    For each keypoint the normalised scale is mag_factor * kp.size / PS and the
    pyramid level is max(0, int(log2(scale))). The returned 2x3 np.matrix
    scales by scale / 2**level, rotates by kp.angle degrees, and maps the patch
    centre (PS/2, PS/2) to kp.pt / 2**level.

    Returns (list of 2x3 np.matrix, list of int pyramid indices).
    """
    matrices = []
    levels = []
    for keypoint in kps:
        cx, cy = keypoint.pt
        scale = mag_factor * keypoint.size / PS
        level = max(0, int(math.log(scale, 2)))
        downscale = 2.0 ** level
        level_scale = scale / downscale
        theta = keypoint.angle * math.pi / 180.0
        c = level_scale * math.cos(theta)
        s = level_scale * math.sin(theta)
        # Translation chosen so the patch centre lands on the keypoint.
        tx = (-c + s) * PS / 2.0 + cx / downscale
        ty = (-s - c) * PS / 2.0 + cy / downscale
        matrices.append(np.matrix([[c, -s, tx],
                                   [s, c, ty]]))
        levels.append(level)
    return matrices, levels

def convert_cv2_plus_A_keypoints(kps: List, A: np.array,  PS: int, mag_factor: float):
    """
    Turns OpenCV keypoints plus per-keypoint 2x2 affine shapes A[i] into
    sampling matrices and pyramid levels.

    Scale and pyramid level follow convert_cv2_keypoints: the scale is
    mag_factor * kp.size / PS and the level is max(0, int(log2(scale))). The
    linear part is (scale / 2**level) * Rot(kp.angle) @ A[i]. The translation
    maps the patch centre (PS/2, PS/2) to kp.pt / 2**level.

    A is indexed by keypoint position, so it needs at least len(kps) entries.
    Returns (list of 2x3 matrices, list of int pyramid indices).
    """
    matrices, levels = [], []
    for idx, keypoint in enumerate(kps):
        cx, cy = keypoint.pt
        scale = mag_factor * keypoint.size / PS
        level = max(0, int(math.log(scale, 2)))
        downscale = 2.0 ** level
        level_scale = scale / downscale
        theta = keypoint.angle * math.pi / 180.0
        cos_t = math.cos(theta)
        sin_t = math.sin(theta)
        rotation = np.matrix([[+level_scale * cos_t, -level_scale * sin_t],
                              [+level_scale * sin_t, +level_scale * cos_t]])
        linear = np.matmul(rotation, np.matrix(A[idx]))
        offset = [[(-linear[0, 0] - linear[0, 1]) * PS / 2.0 + cx / downscale],
                  [(-linear[1, 0] - linear[1, 1]) * PS / 2.0 + cy / downscale]]
        matrices.append(np.concatenate([linear, offset], axis=1))
        levels.append(level)
    return matrices, levels

def convert_xyA(kps: List,  PS: int, mag_factor: float) -> Tuple[List[np.array], List[int]]:
    """
    Converts n x [x y a11 a12 a21 a22] affine regions
    into transformation matrices
    and pyramid indices to extract from for the patch extraction.

    kps: Nx6 numpy array or list of 6-element rows. It is passed through
        np.asarray, so plain nested lists work (an ndarray is used as-is).
    PS: output patch size in pixels.
    mag_factor: multiplier applied to the affine shape before normalising by PS.

    With A' = mag_factor * [[a11, a12], [a21, a22]] / PS and s = sqrt(|det A'|),
    the pyramid level is max(0, int(log2(s))). The returned 2x3 matrix has
    linear part A' / 2**level. Its translation maps the patch centre
    (PS/2, PS/2) to (x, y) / 2**level.
    A degenerate shape (zero determinant) makes math.log raise ValueError.

    Returns (list of 2x3 np.ndarray, list of int pyramid indices).
    """
    Ms = []
    pyr_idxs = []
    for kp in np.asarray(kps):
        x = kp[0]
        y = kp[1]
        Ai = mag_factor * kp[2:].reshape(2, 2) / PS
        # Isotropic scale of the region: square root of |det(A')|.
        s = np.sqrt(np.abs(Ai[0, 0] * Ai[1, 1] - Ai[0, 1] * Ai[1, 0]))
        pyr_idx = max(0, int(math.log(s, 2)))
        d_factor = float(math.pow(2., pyr_idx))
        # Express the transform in the coordinates of pyramid level pyr_idx.
        Ai = Ai / d_factor
        M = np.concatenate([Ai, [
            [(-Ai[0, 0] - Ai[0, 1]) * PS / 2.0 + x / d_factor],
            [(-Ai[1, 0] - Ai[1, 1]) * PS / 2.0 + y / d_factor]]], axis=1)
        Ms.append(M)
        pyr_idxs.append(pyr_idx)
    return Ms, pyr_idxs
