"""
    Evaluation-related code is adapted from
    https://github.com/hughw19/NOCS_CVPR2019
"""
import logging

import cv2
import numpy as np


def setup_logger(logger_name, log_file, level=logging.INFO):
    """Return a named logger that writes to a log file and to the console.

    Args:
        logger_name: name passed to ``logging.getLogger``.
        log_file: path of the log file; it is opened in append mode.
        level: threshold set on the logger (default ``logging.INFO``).

    Returns:
        The configured ``logging.Logger`` instance.
    """
    import os

    logger = logging.getLogger(logger_name)
    logger.setLevel(level)
    formatter = logging.Formatter("%(asctime)s : %(message)s")

    # logging.getLogger hands back the same object for the same name, so a
    # second call would otherwise stack another pair of handlers and every
    # record would be emitted multiple times. Attach only what is missing.
    target_file = os.path.abspath(os.fspath(log_file))
    has_file_handler = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == target_file
        for handler in logger.handlers
    )
    # Exact type check: FileHandler subclasses StreamHandler and must not count.
    has_stream_handler = any(
        type(handler) is logging.StreamHandler for handler in logger.handlers
    )

    if not has_file_handler:
        file_handler = logging.FileHandler(log_file, mode="a")
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if not has_stream_handler:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)
    return logger


def get_sym_info(c):
    """Return the 4-element integer symmetry descriptor for category ``c``.

    Layout of the returned array ``[c0, c1, c2, c3]``:
        c0: face classification -- 0 no symmetry, 1 axis symmetry,
            2 two reflection planes, 3 unimplemented type.
        c1, c2, c3: three-view symmetry flags, corresponding to the
            xy, xz and yz views respectively.

    The Y axis points upwards, the x axis passes through the handle and the
    z axis is the remaining one. For the specific definition, see sketch_loss.

    Args:
        c: category name, e.g. ``"bottle"``.

    Returns:
        A fresh ``np.ndarray`` of shape (4,) with an integer dtype. Categories
        that are not listed get all zeros.
    """
    sym_table = (
        ("bottle", (1, 1, 0, 1)),
        ("bowl", (1, 1, 0, 1)),
        ("camera", (0, 0, 0, 0)),
        ("can", (1, 1, 1, 1)),
        ("laptop", (0, 1, 0, 0)),
        # for mug, we currently mark it as no symmetry (c0 == 0)
        ("mug", (0, 1, 0, 0)),
    )
    # The builtin ``int`` is used as dtype: ``np.int`` was only an alias for
    # it and has been removed from NumPy (AttributeError on NumPy >= 1.24).
    for name, flags in sym_table:
        if c == name:
            return np.array(flags, dtype=int)
    return np.array([0, 0, 0, 0], dtype=int)


def get_mean_shape(c):
    """Look up the hard-coded mean (x, y, z) size of category ``c``.

    Categories can be given by name ("bottle", "bowl", ...) or by an
    eight-digit id string (presumably ShapeNet synset ids -- confirm); the id
    entries are stored as a raw value divided by 4.

    Args:
        c: category name or id string.

    Returns:
        ``np.ndarray`` ``[unitx, unity, unitz]``. For an unrecorded category a
        message is printed and an all-zero array is returned.
    """
    shape_table = (
        ("bottle", (87, 220, 89)),
        ("bowl", (165, 80, 165)),
        ("camera", (88, 128, 156)),
        ("can", (68, 146, 72)),
        ("laptop", (346, 200, 335)),
        ("mug", (146, 83, 114)),
        ("02876657", (324 / 4, 874 / 4, 321 / 4)),
        ("02880940", (675 / 4, 271 / 4, 675 / 4)),
        ("02942699", (464 / 4, 487 / 4, 702 / 4)),
        ("02946921", (450 / 4, 753 / 4, 460 / 4)),
        ("03642806", (581 / 4, 445 / 4, 672 / 4)),
        ("03797390", (670 / 4, 540 / 4, 497 / 4)),
    )
    for known, extents in shape_table:
        if c == known:
            unitx, unity, unitz = extents
            break
    else:
        # No entry matched: fall back to zeros and tell the user.
        unitx = unity = unitz = 0
        print("This category is not recorded in my little brain.")
    # scale residual
    return np.array([unitx, unity, unitz])


def load_obj(path_to_file):
    """Load vertices and faces from a Wavefront OBJ file.

    Only ``v`` (vertex) and ``f`` (face) records are read; every other record
    (comments, normals, texture coordinates, ...) is skipped. For face
    entries of the form ``v/vt/vn`` only the vertex index is kept.

    Args:
        path_to_file: path of the OBJ file.

    Returns:
        vertices: ndarray holding one row of floats per ``v`` record.
        faces: ndarray of zero-based vertex indices, one row per ``f`` record.
    """
    vertices = []
    faces = []
    with open(path_to_file, "r") as f:
        for line in f:
            # Tokenise on arbitrary whitespace so repeated spaces or tabs
            # between fields do not produce empty tokens.
            tokens = line.split()
            if not tokens:
                continue
            keyword = tokens[0]
            if keyword == "v":
                vertices.append([float(value) for value in tokens[1:]])
            elif keyword == "f":
                face = []
                for token in tokens[1:]:
                    index = int(token.split("/")[0])
                    if index < 0:
                        # Negative OBJ indices count back from the most
                        # recently defined vertex (-1 is the latest one).
                        face.append(len(vertices) + index)
                    else:
                        # OBJ indices are 1-based; convert to 0-based.
                        face.append(index - 1)
                faces.append(face)
    vertices = np.asarray(vertices)
    faces = np.asarray(faces)
    return vertices, faces


def create_sphere():
    """Load the sphere template mesh shipped under ``assets/``.

    Returns:
        (vertices, faces) exactly as returned by ``load_obj``.
    """
    # The template is noted to have 642 verts and 1280 faces.
    template_path = "assets/sphere_mesh_template.obj"
    sphere_vertices, sphere_faces = load_obj(template_path)
    return sphere_vertices, sphere_faces


def random_point(face_vertices):
    """Draw one random point inside a triangle via barycentric coordinates.

    Args:
        face_vertices: 2-D array whose first three rows are the triangle
            corners.

    Returns:
        The sampled point, a weighted sum of the three corner rows.
    """
    u, v = np.random.random(2)
    root_u = np.sqrt(u)
    # Barycentric weights; by construction they add up to 1.
    weight_a = 1 - root_u
    weight_b = root_u * (1 - v)
    weight_c = root_u * v
    corner_a = face_vertices[0, :]
    corner_b = face_vertices[1, :]
    corner_c = face_vertices[2, :]
    return weight_a * corner_a + weight_b * corner_b + weight_c * corner_c


def pairwise_distance(A, B):
    """Compute the Euclidean distance between every pair of points.

    Args:
        A: n x 3 numpy array
        B: m x 3 numpy array
    Return:
        C: n x m numpy array, C[i, j] is the distance from A[i] to B[j]
    """
    a_expanded = A[:, :, np.newaxis]  # n x 3 x 1
    b_expanded = np.transpose(B[:, :, np.newaxis])  # 1 x 3 x m
    squared_diff = np.square(a_expanded - b_expanded)  # n x 3 x m
    # Collapse the coordinate axis to get squared distances, then take the root.
    return np.sqrt(squared_diff.sum(axis=1))


def uniform_sample(vertices, faces, n_samples, with_normal=False):
    """Sample points on a triangle mesh, picking faces in proportion to area.

    Args:
        vertices: array of vertex coordinates, indexed by ``faces``.
        faces: integer array of vertex indices, three per triangle.
        n_samples: number of points to draw.
        with_normal: if True, append the unit normal of the sampled triangle
            to each point.

    Returns:
        n_samples x 3 array of points, or n_samples x 6 (point followed by
        normal) when ``with_normal`` is True.
    """
    triangles = vertices[faces]
    edge_ab = triangles[:, 1, :] - triangles[:, 0, :]
    edge_ac = triangles[:, 2, :] - triangles[:, 0, :]
    # The cross product is the (unnormalised) face normal; half its length is
    # the triangle area.
    face_normals = np.cross(edge_ab, edge_ac)
    area_cdf = np.cumsum(0.5 * np.linalg.norm(face_normals, axis=1))

    points = np.zeros((n_samples, 3), dtype=float)
    point_normals = np.zeros((n_samples, 3), dtype=float)
    for row in range(n_samples):
        # Locate a uniform draw in the cumulative area to choose a face.
        threshold = np.random.random() * area_cdf[-1]
        tri_idx = np.searchsorted(area_cdf, threshold)
        points[row] = random_point(triangles[tri_idx, :, :])
        point_normals[row] = face_normals[tri_idx]
    lengths = np.linalg.norm(point_normals, axis=1, keepdims=True)
    point_normals = point_normals / lengths
    if not with_normal:
        return points
    return np.concatenate((points, point_normals), axis=1)


def farthest_point_sampling(points, n_samples):
    """Greedily pick point indices that are far apart from each other.

    Selection starts at index 0. Each following pick is the point whose
    distance to the already selected set is largest.

    Args:
        points: point array accepted by ``pairwise_distance``.
        n_samples: number of indices to select.

    Returns:
        Integer array of shape (n_samples,) with the selected indices.
    """
    chosen = np.zeros((n_samples,), dtype=int)
    all_dists = pairwise_distance(points, points)
    # start from first point
    current = 0
    nearest_selected = all_dists[:, current]
    for slot in range(n_samples):
        chosen[slot] = current
        # Keep, per point, the distance to its closest selected point.
        nearest_selected = np.minimum(nearest_selected, all_dists[:, current])
        current = np.argmax(nearest_selected)
    return chosen


def sample_points_from_mesh(path, n_pts, with_normal=False, fps=False, ratio=2):
    """Sample points from the surface of a mesh stored in an OBJ file.

    Args:
        path: path to OBJ file.
        n_pts: int, number of points being sampled.
        with_normal: return points with normal, approximated by mesh triangle normal
        fps: whether to thin the samples with farthest point sampling, default False.
        ratio: int, if fps is used, ratio*n_pts points are sampled first and
            farthest point sampling then keeps n_pts of them.
    Returns:
        points: n_pts x 3, n_pts x 6 if with_normal = True
    """
    vertices, faces = load_obj(path)
    if not fps:
        return uniform_sample(vertices, faces, n_pts, with_normal)
    # Oversample, then keep a well-spread subset chosen on xyz only.
    dense_points = uniform_sample(vertices, faces, ratio * n_pts, with_normal)
    keep_idx = farthest_point_sampling(dense_points[:, :3], n_pts)
    return dense_points[keep_idx]


def load_depth(depth_path):
    """Load a depth image and return it as a 2-D uint16 array.

    Two layouts are accepted:
        * a multi-channel image in which depth is encoded as
          ``channel1 * 256 + channel2``;
        * a single-channel uint16 image, returned as loaded.

    Args:
        depth_path: path of the image, read with ``cv2.imread(path, -1)``.

    Returns:
        The uint16 depth map.

    Raises:
        OSError: if OpenCV could not read the file.
        AssertionError: if the image has an unsupported layout or dtype.
    """
    depth = cv2.imread(depth_path, -1)
    if depth is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(
            "[ Error ]: Failed to read depth image: {}".format(depth_path)
        )
    if len(depth.shape) == 3:
        # This is encoded depth image, let's convert
        # NOTE: RGB is actually BGR in opencv
        # Widen to uint16 before scaling: multiplying a uint8 array by the
        # Python int 256 raises OverflowError under NumPy 2 promotion rules.
        high_byte = depth[:, :, 1].astype(np.uint16)
        low_byte = depth[:, :, 2].astype(np.uint16)
        depth16 = high_byte * 256 + low_byte
        # 32001 is mapped to 0 (presumably a "no depth" marker -- confirm).
        depth16 = np.where(depth16 == 32001, 0, depth16)
        depth16 = depth16.astype(np.uint16)
    elif len(depth.shape) == 2 and depth.dtype == "uint16":
        depth16 = depth
    else:
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        raise AssertionError("[ Error ]: Unsupported depth type.")
    return depth16


def get_bbox(bbox):
    """Compute a square crop window around a bounding box.

    The window side is the next multiple of 40 strictly above the longer box
    side, capped at 440. The window is centred on the box and then shifted so
    it stays inside a 480-row by 640-column image.

    Args:
        bbox: (y1, x1, y2, x2) bounding box.

    Returns:
        (rmin, rmax, cmin, cmax) row and column limits of the window.
    """
    # Names kept from the original: img_width bounds rows, img_length columns.
    img_width, img_length = 480, 640

    def _shift_inside(low, high, limit):
        # Slide the [low, high] interval so it starts at >= 0 and ends at
        # <= limit, keeping its length.
        if low < 0:
            high += -low
            low = 0
        if high > limit:
            low -= high - limit
            high = limit
        return low, high

    y1, x1, y2, x2 = bbox
    window_size = (max(y2 - y1, x2 - x1) // 40 + 1) * 40
    window_size = min(window_size, 440)
    half = int(window_size / 2)
    row_mid = (y1 + y2) // 2
    col_mid = (x1 + x2) // 2

    rmin, rmax = _shift_inside(row_mid - half, row_mid + half, img_width)
    cmin, cmax = _shift_inside(col_mid - half, col_mid + half, img_length)
    return rmin, rmax, cmin, cmax


def compute_sRT_errors(sRT1, sRT2):
    """Compare two similarity transforms (scale, rotation, translation).

    Args:
        sRT1: [4, 4]. homogeneous affine transformation
        sRT2: [4, 4]. homogeneous affine transformation
    Returns:
        R_error: angle difference in degree,
        T_error: Euclidean distance between the translations,
        IoU: despite its name, the relative scale error ``|s1 - s2| / s2``
            (the original code marks this value as "todo wrong!!").
    """
    # Sanity check on the homogeneous bottom rows. It only reports a mismatch
    # and never aborts. An explicit condition is used instead of ``assert``
    # so that the check is not stripped under ``python -O``.
    bottom_row1 = sRT1[3, :]
    bottom_row2 = sRT2[3, :]
    rows_ok = np.array_equal(bottom_row1, bottom_row2) and np.array_equal(
        bottom_row1, np.array([0, 0, 0, 1])
    )
    if not rows_ok:
        print(sRT1[3, :], sRT2[3, :])

    # The upper-left 3x3 block is scale * rotation, so the cube root of its
    # determinant recovers the scale.
    s1 = np.cbrt(np.linalg.det(sRT1[:3, :3]))
    R1 = sRT1[:3, :3] / s1
    T1 = sRT1[:3, 3]
    s2 = np.cbrt(np.linalg.det(sRT2[:3, :3]))
    R2 = sRT2[:3, :3] / s2
    T2 = sRT2[:3, 3]
    R12 = R1 @ R2.transpose()
    # Rotation angle of the relative rotation; the clip guards arccos against
    # values slightly outside [-1, 1] caused by rounding.
    R_error = np.arccos(np.clip((np.trace(R12) - 1) / 2, -1.0, 1.0)) * 180 / np.pi
    T_error = np.linalg.norm(T1 - T2)
    # todo wrong!!
    IoU = np.abs(s1 - s2) / s2

    return R_error, T_error, IoU
