#!python3

import sys, math
from time import time as now
from array import array
import argparse

# Pillow is used at the end of the script to wrap the rendered buffer as an
# image and save it, so a missing install is reported with a readable hint
# instead of a traceback.
try:
    from PIL.Image import frombuffer
except ImportError:
    # Catch only the "package not importable" case; a bare `except:` would
    # also swallow KeyboardInterrupt and unrelated errors.
    print("Please install python package 'pillow' in order to run this")
    # Exit with a non-zero status so callers can tell the run failed.
    sys.exit(1)

import metalcompute as mc

# Command-line options. All of them are optional single-dash flags, e.g.
#   -width 1024 -height 768 -outname out.png
parser = argparse.ArgumentParser()
parser.add_argument("-width", help="specify width of image", type=int, default=4096)
# The help text previously said "width" here (copy-paste slip).
parser.add_argument("-height", help="specify height of image", type=int, default=4096)
parser.add_argument("-outname", help="output image file name", type=str, default="raymarch.jpg")
# NOTE(review): the default of -1 is passed straight to mc.Device();
# presumably that selects the default device — confirm against metalcompute.
parser.add_argument("-dev", help="metal device index", type=int, default=-1)
args = parser.parse_args()

# Opening section of the Metal Shading Language source for the `test` kernel.
# The kernel is indexed by `id` (thread position in the grid) and uses it as
# a linear pixel index:
#   * buffer(0) `uniform` is read as floats: [0] width, [1] height,
#     [2..4] camera position, [5..7] look-at target.
#   * buffer(1) `out` is an array of uchar4 (written in kernel_end).
# From `id` it derives a uv coordinate centred on zero (about -0.5..0.5 on
# both axes), builds a camera basis from the view direction and a fixed +Y
# up vector, and forms the normalized ray direction for this pixel.
# `uv *= 1.0` currently has no effect (presumably a scale knob for tuning).
# The ball is centred on the look-at target with radius 0.2, and the floor
# is the plane y = -0.6.
# The function body is intentionally left open: the kernel_step fragments
# are appended after this string and kernel_end closes the function.
kernel_start = """
#include <metal_stdlib>
using namespace metal;

kernel void test(const device float *uniform [[ buffer(0) ]],
                device uchar4 *out [[ buffer(1) ]],
                uint id [[ thread_position_in_grid ]]) {
    float width = uniform[0];
    float height = uniform[1];
    float3 camera = float3(uniform[2], uniform[3], uniform[4]);
    float3 target = float3(uniform[5], uniform[6], uniform[7]);
    float2 uv = float2((id%int(width))/width, 1.0-(id/int(width))/height);
    uv -= 0.5;
    uv *= 1.0;
    // Projection
    float3 up = float3(0.0,1.0,0.0);
    float3 camdir = normalize(target - camera);
    float3 camh = normalize(cross(camdir, up));
    float3 camv = normalize(cross(camdir, camh));
    float3 raydir = normalize(uv.x*camh-uv.y*camv+camdir);
    float raylen = 0.0;
    float3 ball = target;
    float ball_radius = 0.2;
    float ball_radius_sq = ball_radius*ball_radius;
    float floory = -0.6;
    float balld=0.0, floord=0.0, neard=0.0;
    float3 pos = float3(0.0);
"""

# One ray-marching step. This fragment is repeated `raymarch_steps` times
# when the kernel source is assembled, i.e. the march loop is unrolled by
# string repetition rather than written as a loop in the shader.
# The backslash after the opening quotes keeps the fragment from starting
# with an empty line.
# Per step: recompute the position along the ray, measure it against the
# ball and against the floor plane, and advance the ray by the smaller value.
# `fract(pos - ball) - .5` wraps the offset into a unit cell, so the ball
# test repeats once per unit along every axis.
# NOTE(review): balld is a difference of squared lengths, not a Euclidean
# distance — presumably acceptable for this demo; confirm if reused.
kernel_step = """\
    pos = camera + raylen * raydir;
    balld = length_squared(fract(pos - ball)-.5)-ball_radius_sq;
    floord = pos.y - floory;        
    neard = min(balld, floord);
    raylen += neard;
"""

# Closing section of the kernel: shade the final march position and store
# the pixel, then close the function opened in kernel_start.
#   * fog = min(1, 4 / raylen), forced to 0 once the ray has travelled more
#     than 100 units, so distant results fade to black.
#   * The colour is the fractional part of the final position scaled by fog,
#     converted to 8-bit channels and written to out[id] with alpha 255.
kernel_end = """\
    pos = camera + raylen * raydir;
    float fog = min(1.0, 4.0 / raylen);
    if (raylen > 100.0) fog = 0.0;
    float3 col = fract(pos) * fog;
    out[id] = uchar4(uchar3(col.xyz*255.),255);
}
"""

# Scene setup shared with the kernel: the camera position and the point it
# looks at (the kernel also centres the ball on that point).
camera = [1.0, 1.0, -2.0]
target = [0.0, 0.9, 0.0]

# Packed as C floats in the order the kernel indexes `uniform[]`:
#   [0] width, [1] height, [2..4] camera xyz, [5..7] target xyz.
uniforms = array('f', [args.width, args.height, *camera, *target])


# Open the requested Metal device and report which one was selected.
dev = mc.Device(args.dev)
print(f"Device:{dev}")

# Output buffer: 4 bytes per pixel, matching the uchar4 the kernel writes.
image = dev.buffer(4 * args.width * args.height)

# The shader has no march loop; the step fragment is repeated this many
# times between the kernel prologue and epilogue before compiling.
raymarch_steps = 300
kernel = "".join((kernel_start, kernel_step * raymarch_steps, kernel_end))
render_fn = dev.kernel(kernel).function("test")

# Run the kernel over width * height indices and report how long the call
# took. `now` is time.time, i.e. wall-clock seconds.
# The arguments are passed in the order (count, uniforms, image); presumably
# uniforms binds to buffer(0) and image to buffer(1) — confirm against
# metalcompute.
# NOTE(review): the measurement covers only the render_fn(...) call itself.
# Whether that call blocks until the GPU has finished depends on
# metalcompute (the return value is deliberately not kept here) — confirm
# before treating this figure as the true render time.
start = now()
render_fn(args.width * args.height, uniforms, image)
end = now()

print(f"Render took {end-start:1.6}s")

# Wrap the rendered bytes as an RGBA image (the kernel writes one uchar4 per
# pixel) and save it under the requested file name.
image_buf = frombuffer("RGBA", (args.width, args.height), data=image)

# JPEG has no alpha channel, so the image must be converted to RGB before
# saving. The check previously matched only names ending in "jpg", which
# left "-outname picture.jpeg" as RGBA and made save() fail.
if args.outname.lower().endswith(("jpg", "jpeg")):
    image_buf = image_buf.convert('RGB')
image_buf.save(args.outname)
