import pytest
import numpy as np
import torch
import gsoup
from pathlib import Path

def test_type_conversions():
    """Check the dtypes produced by ``to_float`` and ``to_8b`` for both backends.

    Starting from a boolean mask, the chain bool -> float, bool -> 8-bit,
    float -> 8-bit and 8-bit -> float is exercised once with a numpy array
    and once with a torch tensor. The result must be float32 / uint8 in the
    matching library.
    """
    mask_np = np.array([1, 0, 0, 1], dtype=bool)
    mask_torch = gsoup.to_torch(mask_np)

    # (input mask, expected float dtype, expected 8-bit dtype)
    backends = (
        (mask_np, np.float32, np.uint8),
        (mask_torch, torch.float32, torch.uint8),
    )
    for mask, float_dtype, byte_dtype in backends:
        as_float = gsoup.to_float(mask)
        assert as_float.dtype == float_dtype
        assert gsoup.to_8b(mask).dtype == byte_dtype
        as_byte = gsoup.to_8b(as_float)
        assert as_byte.dtype == byte_dtype
        assert gsoup.to_float(as_byte).dtype == float_dtype

def test_broadcast_batch():
    """Check that ``broadcast_batch`` aligns the leading batch axis of R and t.

    Each case gives the input shapes of R and t and the shapes expected after
    broadcasting. Unbatched 1-D inputs gain a leading batch axis of size 1.
    A (3, 3) R paired with a (3,) t must raise ``ValueError``.
    """
    cases = [
        # (R shape in, t shape in, R shape out, t shape out)
        ((1, 3, 3), (1, 3), (1, 3, 3), (1, 3)),
        ((2, 3, 3), (2, 3), (2, 3, 3), (2, 3)),
        ((2, 3, 3), (1, 3), (2, 3, 3), (2, 3)),
        ((1, 3, 3), (2, 3), (2, 3, 3), (2, 3)),
        ((3,), (3,), (1, 3), (1, 3)),
        ((1,), (1,), (1, 1), (1, 1)),
    ]
    for r_in, t_in, r_out, t_out in cases:
        R, t = gsoup.broadcast_batch(np.random.randn(*r_in), np.random.randn(*t_in))
        assert R.shape == r_out
        assert t.shape == t_out

    mismatched_R = np.random.randn(3, 3)
    mismatched_t = np.random.randn(3)
    with pytest.raises(ValueError):
        gsoup.broadcast_batch(mismatched_R, mismatched_t)

def test_transforms():
    """Check rigid-transform composition, look-at and projection conventions.

    Covers:
    - ``compose_rt`` broadcasting a batch-1 translation against batch-2 rotations;
    - numpy/torch parity of ``look_at`` in both the default and OpenGL conventions;
    - that a point a camera looks straight at projects to the screen centre,
      both through an OpenGL projection and through OpenCV intrinsics;
    - for a slightly perturbed point, the expected relation between the two
      screen spaces.
    """
    R = np.random.randn(2, 3, 3)
    t = np.random.randn(1, 3)
    Rt = gsoup.compose_rt(R, t)
    # t's batch of 1 is broadcast against R's batch of 2
    assert Rt.shape == (2, 3, 4)

    eye = np.array([1, 0, 0], dtype=np.float32)
    at = np.array([0, 0, 0], dtype=np.float32)
    up = np.array([0, 0, 1], dtype=np.float32)
    # numpy and torch implementations must agree, in both conventions
    transform_np = gsoup.look_at_np(eye, at, up)
    transform_torch = gsoup.look_at_torch(gsoup.to_torch(eye), gsoup.to_torch(at), gsoup.to_torch(up))
    assert np.allclose(transform_np, gsoup.to_numpy(transform_torch))
    transform_np_opengl = gsoup.look_at_np(eye, at, up, opengl=True)
    transform_torch_opengl = gsoup.look_at_torch(gsoup.to_torch(eye), gsoup.to_torch(at), gsoup.to_torch(up), opengl=True)
    assert np.allclose(transform_np_opengl, gsoup.to_numpy(transform_torch_opengl))

    location3d = np.array([0.1, 0.1, 0.1])
    location3d_noise = location3d + np.random.randn(3) * 0.01
    location4d = gsoup.to_hom(location3d)
    location4d_noise = gsoup.to_hom(location3d_noise)

    # sanity on opengl projection: a camera aimed at the point sees it at (0, 0).
    # look_at_np's result is indexed with [0], presumably a batch of one -- TODO confirm
    v2w = gsoup.look_at_np(np.array([1., 0, 0]), location3d, np.array([0, 0, 1.0]), opengl=True)[0]
    w2v = np.linalg.inv(v2w)
    v2c = gsoup.perspective_projection()
    opengl_location = v2c @ w2v @ location4d
    opengl_location = gsoup.homogenize(opengl_location)
    assert np.allclose(opengl_location[:2], np.zeros(2))

    # sanity on opencv projection: with intrinsics built for a 1x1 image the
    # same point must land at (0.5, 0.5)
    c2w = gsoup.look_at_np(np.array([1., 0, 0]), location3d, np.array([0, 0, 1.0]))[0]
    w2c = np.linalg.inv(c2w)
    K = gsoup.opencv_intrinsics_from_opengl_project(v2c, 1, 1)
    opencv_location = K @ gsoup.to_34(w2c) @ location4d
    opencv_location = gsoup.homogenize(opencv_location)
    assert np.allclose(opencv_location[:2], np.ones(2) * 0.5)

    # test opencv / opengl conversion on an off-centre point
    opengl_location = v2c @ w2v @ location4d_noise
    opengl_location = gsoup.homogenize(opengl_location)
    opencv_location = K @ gsoup.to_34(w2c) @ location4d_noise
    opencv_location = gsoup.homogenize(opencv_location)
    # x is in same direction, but opengl screen is -1 to 1 (while opencv is 0 to 1)
    assert np.allclose((opencv_location[0] - opengl_location[0] / 2), 0.5)
    # y is in opposite direction, but opengl screen is -1 to 1 (while opencv is 0 to 1)
    assert np.allclose((opencv_location[1] + opengl_location[1] / 2), 0.5)

def test_rotations():
    """Check quaternion/matrix conversions, hemisphere sampling and ``rotx``.

    - Quaternion -> matrix -> quaternion must round-trip up to an overall sign
      (q and -q encode the same rotation), in batched and single form.
    - Vectors drawn by ``random_vectors_on_sphere`` around a normal must lie
      in that normal's hemisphere, for a shared normal and for per-sample normals.
    - ``rotx(pi/2)`` must equal the expected 4x4 homogeneous rotation about x.
    """
    qvecs = gsoup.random_qvec(10)
    torch_qvecs = torch.tensor(qvecs)
    rotmats = gsoup.batch_qvec2mat(qvecs)
    assert rotmats.shape == (10, 3, 3)
    rotmats = gsoup.batch_qvec2mat(torch_qvecs)
    assert rotmats.shape == (10, 3, 3)
    new_qvecs = gsoup.batch_mat2qvec(rotmats)
    # The sign ambiguity is per quaternion, not per component: mixing signs
    # across components yields a different rotation, so require each whole row
    # to match either +q or -q.
    same_sign = (torch.abs(new_qvecs - torch_qvecs) < 1e-6).all(dim=-1)
    flipped_sign = (torch.abs(new_qvecs + torch_qvecs) < 1e-6).all(dim=-1)
    assert torch.all(same_sign | flipped_sign)

    rotmat = gsoup.qvec2mat(torch_qvecs[0])
    assert rotmat.shape == (3, 3)
    new_qvec = gsoup.mat2qvec(rotmat)
    same_sign = torch.all(torch.abs(new_qvec - torch_qvecs[0]) < 1e-6)
    flipped_sign = torch.all(torch.abs(new_qvec + torch_qvecs[0]) < 1e-6)
    assert same_sign or flipped_sign

    # A single normal shared by all samples: every dot product must be positive.
    # (The previous `(dots).all() > 0` only checked that no dot was exactly zero.)
    normal = torch.tensor([0, 0, 1.0])
    random_vectors = gsoup.random_vectors_on_sphere(10, normal=normal)
    assert random_vectors.shape == (10, 3)
    assert ((random_vectors @ normal) > 0).all()

    # One normal per sample: batched row-wise dot products, shape (10, 1, 1)
    normal = torch.tensor([[0, 0, 1.0]]).repeat(10, 1)
    random_vectors = gsoup.random_vectors_on_sphere(10, normal=normal)
    assert random_vectors.shape == (10, 3)
    assert ((random_vectors[:, None, :] @ normal[:, :, None]) > 0).all()

    rotx = gsoup.rotx(np.pi / 2, degrees=False)
    assert np.allclose(rotx, np.array([[1., 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]))

def test_homogenize():
    """Check the shapes produced by ``to_hom`` and ``homogenize``.

    ``to_hom`` appends one coordinate to the last axis, and ``homogenize``
    removes one. With ``keepdim=True``, ``homogenize`` keeps the last axis
    length unchanged. This is checked for a batch of points and for a single
    unbatched point.
    """
    points = np.random.rand(100, 2)
    lifted = gsoup.to_hom(points)
    assert lifted.shape == (100, 3)
    lifted = gsoup.to_hom(lifted)
    assert lifted.shape == (100, 4)
    lowered = gsoup.homogenize(lifted)
    assert lowered.shape == (100, 3)
    assert gsoup.homogenize(lowered, keepdim=True).shape == (100, 3)

    single = np.random.rand(3)
    single_hom = gsoup.to_hom(single)
    assert single_hom.shape == (4,)
    assert gsoup.homogenize(single_hom).shape == (3,)

def test_normalize_vertices():
    """Vertices spread over [0, 100) must come back strictly below 1 after normalizing.

    The check is run once with a numpy array and once with a torch tensor.
    """
    # Samplers are called lazily so each input is created right before it is used.
    samplers = (
        lambda: np.random.rand(100, 3) * 100,
        lambda: torch.rand(100, 3) * 100,
    )
    for make_vertices in samplers:
        normalized = gsoup.normalize_vertices(make_vertices())
        assert (normalized < 1.0).all()

def test_structures():
    """Round-trip the built-in cube and icosahedron meshes through OBJ files.

    The cube is written with ``save_obj`` and the icosahedron with
    ``save_mesh``. Both are read back with ``load_obj`` and compared with the
    original data. Finally, the cube file is reloaded and checked to have
    8 vertices and 12 faces.
    """
    round_trips = (
        (gsoup.structures.cube, gsoup.save_obj, "resource/cube.obj"),
        (gsoup.structures.icosehedron, gsoup.save_mesh, "resource/ico.obj"),
    )
    for build, save, path in round_trips:
        verts, faces = build()
        save(verts, faces, path)
        loaded_verts, loaded_faces = gsoup.load_obj(path)
        assert np.allclose(verts, loaded_verts)
        assert np.allclose(faces, loaded_faces)

    cube_verts, cube_faces = gsoup.load_obj("resource/cube.obj")
    assert cube_verts.shape[0] == 8
    assert cube_faces.shape[0] == 12

def test_image():
    """Exercise pattern generators, image I/O, color conversion and resizing.

    Steps:
    - Write several generated patterns under ``resource/``.
    - Check a save/load round trip and a linear->sRGB->linear round trip on a
      lollipop pattern.
    - Check the gray-gradient generator's shapes and value ranges.
    - Reload a 512x512 Voronoi image through ``load_image`` / ``load_images``
      with different options, checking the shape and dtype each produces.
    """
    # -- pattern generators, save/load and color round trips ---------------------
    board = gsoup.generate_checkerboard(512, 512, 8)
    gsoup.save_image(board, "resource/checkboard.png")
    lollipop_path = Path("resource/lollipop.png")
    lollipop = gsoup.generate_lollipop_pattern(512, 512, dst=lollipop_path)
    gsoup.save_image(lollipop, lollipop_path)
    reloaded = gsoup.load_image(lollipop_path)
    assert np.allclose(lollipop, reloaded)
    gsoup.save_images(lollipop[np.newaxis], lollipop_path.parent, file_names=["test_save.png"])
    padded = gsoup.pad_image_to_res(lollipop[np.newaxis], 512, 1024)
    assert padded.shape == (1, 512, 1024, 3)
    back_to_linear = gsoup.srgb_to_linear(gsoup.linear_to_srgb(lollipop))
    assert np.allclose(lollipop, back_to_linear)
    gsoup.generate_concentric_circles(256, 512, dst=Path("resource/circles.png"))
    gsoup.generate_stripe_pattern(256, 512, direction="both", dst=Path("resource/stripe.png"))
    gsoup.generate_dot_pattern(512, 256, dst=Path("resource/dots.png"))

    # -- gray gradients ------------------------------------------------------------
    vertical_gray = gsoup.generate_gray_gradient(256, 256, grayscale=True, dst=Path("resource/gg_vert.png"))
    assert vertical_gray.shape == (256, 256)
    assert len(np.unique(vertical_gray)) == 10
    assert vertical_gray.max() == 255
    gradient_cases = [
        # (height, width, kwargs, expected shape or None to skip, expected max)
        (50, 800, {"vertical": False, "dst": Path("resource/gg_horiz.png")}, (50, 800, 3), 255),
        (256, 256, {"bins": -65, "dst": Path("resource/gg_bin_min.png")}, None, 0),
        (256, 256, {"bins": 300, "dst": Path("resource/gg_bin_max.png")}, None, 255),
        (1080, 1920, {"bins": 300, "dst": Path("resource/gg_highres.png")}, (1080, 1920, 3), 255),
    ]
    for height, width, kwargs, shape, peak in gradient_cases:
        gradient = gsoup.generate_gray_gradient(height, width, **kwargs)
        if shape is not None:
            assert gradient.shape == shape
        assert gradient.max() == peak

    # -- single-image loading options -------------------------------------------
    dst = Path("resource/voronoi.png")
    gsoup.generate_voronoi_diagram(512, 512, 1000, dst=dst)
    single_loads = [
        # (load_image kwargs, expected shape, expected dtype or None to skip)
        ({}, (512, 512, 3), np.uint8),
        ({"as_grayscale": True}, (512, 512), np.uint8),
        ({"to_float": True}, (512, 512, 3), np.float32),
        ({"channels_last": False}, (3, 512, 512), None),
        ({"to_float": True, "as_grayscale": True}, (512, 512), np.float32),
        ({"channels_last": False, "as_grayscale": True}, (512, 512), None),
    ]
    for kwargs, shape, dtype in single_loads:
        loaded = gsoup.load_image(dst, **kwargs)
        assert loaded.shape == shape
        if dtype is not None:
            assert loaded.dtype == dtype
        if kwargs.get("to_float"):
            # float images are expected to be scaled into [0, 1]
            assert (loaded >= 0.0).all()
            assert (loaded <= 1.0).all()

    # -- batched loading, resizing and gridding ---------------------------------
    assert gsoup.load_images(dst).shape == (1, 512, 512, 3)
    assert gsoup.load_images(dst, as_grayscale=True).shape == (1, 512, 512)
    assert gsoup.load_images([dst]).shape == (1, 512, 512, 3)
    batch = gsoup.load_images([dst] * 4)
    assert batch.shape == (4, 512, 512, 3)
    resized = gsoup.resize_images_naive(batch, 256, 256, mode="mean")
    assert resized.shape == (4, 256, 256, 3)
    resized_float = gsoup.resize_images_naive(gsoup.to_float(batch), 256, 256, mode="mean")
    assert resized_float.shape == (4, 256, 256, 3)
    assert resized_float.dtype == np.float32
    resized_single_channel = gsoup.resize_images_naive(batch[..., 0:1], 256, 256, mode="mean")
    assert resized_single_channel.shape == (4, 256, 256, 1)
    grid = gsoup.image_grid(resized, 2, 2)
    assert grid.shape == (512, 512, 3)
    assert gsoup.load_images([dst] * 4, as_grayscale=True).shape == (4, 512, 512)
    # grayscale output has no channel axis, so channels_last=False leaves the shape as-is
    assert gsoup.load_images([dst] * 4, as_grayscale=True, channels_last=False).shape == (4, 512, 512)
    assert gsoup.load_images([dst] * 4, resize_wh=(128, 128)).shape == (4, 128, 128, 3)
    batch, paths = gsoup.load_images(
        [dst] * 4,
        resize_wh=(128, 256),
        as_grayscale=True,
        channels_last=False,
        return_paths=True,
        to_float=True,
        to_torch=True,
    )
    assert len(paths) == 4
    assert batch.dtype == torch.float32
    # resize_wh is (width, height): the result has 256 rows and 128 columns
    assert batch.shape == (4, 256, 128)

def test_video():
    """Round-trip random frames through video encoding and the video readers.

    Writes ``frame_number`` random 512x512 RGB frames to an AVI file. Then
    checks that the following all agree exactly with the source frames:
    - streaming with ``VideoReader`` (which must yield every frame);
    - bulk loading with ``load_video``;
    - reversing, slicing, extracting frames to images;
    - querying frame timestamps.
    The exact-equality checks mean the test expects a lossless round trip.
    """
    frame_number = 100
    images = np.random.randint(0, 255, (frame_number, 512, 512, 3), dtype=np.uint8)
    dst = Path("resource/noise.avi")
    gsoup.save_video(images, dst, fps=10)

    reader = gsoup.VideoReader(dst, h=512, w=512)
    fps = gsoup.FPS()
    frames_read = 0
    for i, frame in enumerate(reader):
        print("{}: {}, fps: {}".format(i, frame.shape, fps()))
        assert np.all(frame == images[i])
        frames_read += 1
    # The per-frame loop above passes vacuously if the reader stops early or
    # yields nothing, so check the count explicitly.
    assert frames_read == frame_number

    video_frames = gsoup.load_video(dst)
    assert video_frames.shape == (frame_number, 512, 512, 3)
    assert np.all(video_frames == images)
    video_frames_reversed = gsoup.reverse_video(dst)
    assert (video_frames_reversed[-1] == video_frames[0]).all()
    # end_frame is treated as inclusive here: frames 0, 2, 4, 6
    sliced_frames = gsoup.slice_from_video(dst, every_n_frames=2, start_frame=0, end_frame=6)
    assert (sliced_frames == video_frames[:7:2, :, :, :]).all()
    gsoup.video_to_images(dst, Path("resource/noise"))
    discrete_images = gsoup.load_images(Path("resource/noise"))
    assert discrete_images.shape == (frame_number, 512, 512, 3)
    timestamps = gsoup.get_frame_timestamps(dst)
    assert timestamps[0] == 0

def test_procam():
    """Decode gray-code patterns into pixel correspondences and warp an image through them.

    The gray-code patterns are fed back in as the "captured" images, expanded
    to 3 channels. The recovered projector-to-camera mapping should therefore
    be close to identity. Warping a lollipop pattern through it should then
    reproduce the pattern with a small mean absolute error.
    """
    gc_patterns = gsoup.generate_gray_code(128, 128, 1)
    # assumes gc_patterns is (N, H, W): shape[2] / shape[1] are passed as width / height -- TODO confirm
    c2p, p2c = gsoup.pix2pix_correspondence(gc_patterns.shape[2], gc_patterns.shape[1],
                                            1, gc_patterns[..., None].repeat(3, axis=-1),
                                            verbose=False, debug=True, output_dir=Path("resource/pix2pix"))
    desired = gsoup.generate_lollipop_pattern(128, 128)
    desired = gsoup.to_float(desired)
    warp_image = gsoup.warp_image(p2c, desired, output_path=Path("resource/warp.png"))
    assert warp_image.shape == (128, 128, 3)
    assert warp_image.dtype == np.uint8
    # Both operands are uint8; subtracting them directly wraps around on
    # underflow (e.g. 10 - 20 == 246). Widen to a signed type first.
    error = np.abs(gsoup.to_8b(desired).astype(np.int16) - warp_image.astype(np.int16))
    assert np.mean(error) < 50  # surely an identity correspondence & warp can't be too bad

def test_sphere_tracer():
    """Sphere-trace a radius-0.25 sphere SDF from random cameras and save the renders.

    Five random cameras are placed on the unit sphere. For each camera, rays
    are generated and ``gsoup.render`` is called; its output is reshaped to
    an (image_size, image_size, 4) image. The renders are alpha-composited,
    a gizmo is drawn on them using the combined ``v2c @ w2v`` transform, and
    the results are saved under ``resource/sphere_trace``.

    The test runs on ``cuda:0``, so it is skipped when no CUDA device is
    available instead of erroring.
    """
    if not torch.cuda.is_available():
        pytest.skip("sphere tracer test requires a CUDA device")
    image_size = 512
    device = "cuda:0"
    w2v, v2c = gsoup.create_random_cameras_on_unit_sphere(5, 1.0, opengl=True, device=device)
    ray_origins, ray_directions = gsoup.generate_rays(w2v, v2c[0], image_size, image_size, device=device)
    sdf = gsoup.structures.sphere_sdf(0.25)
    images = []
    for o, d in zip(ray_origins, ray_directions):
        # flatten the per-camera ray grids into (num_rays, 3) for the renderer
        result = gsoup.render(sdf, o.view(-1, 3), d.view(-1, 3))
        images.append(result.view(image_size, image_size, 4))
    images = gsoup.to_np(torch.stack(images))
    images = gsoup.alpha_compose(images)
    gizmo_images = gsoup.draw_gizmo_on_image(images, gsoup.to_np(v2c @ w2v), opengl=True)
    gsoup.save_images(gizmo_images, Path("resource/sphere_trace"))

def test_qem():
    """Decimate the built-in cube with ``gsoup.qem`` and check it ends with exactly 4 faces."""
    vertices, faces = gsoup.structures.cube()
    _, decimated_faces = gsoup.qem(vertices, faces, budget=4)
    assert decimated_faces.shape[0] == 4